Example #1
    def check_convergence(self, loss):
        """
        check if the loss converges
        """
        LOGGER.info("check convergence")
        if self.convergence is None:
            self.convergence = converge_func_factory("diff", self.tol)

        return self.convergence.is_converge(loss)
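Every snippet on this page revolves around converge_func_factory(early_stop, tol), which returns an object exposing is_converge(loss). As a minimal standalone sketch of that contract (assuming only the behavior visible in these examples, with the 'diff' strategy stopping once consecutive losses differ by at most tol):

    converge_func = converge_func_factory("diff", 1e-5)  # 'diff': |loss - prev_loss| <= tol
    for epoch, loss in enumerate([0.9, 0.5, 0.49, 0.489999]):
        if converge_func.is_converge(loss):
            print("converged at epoch", epoch)  # fires once the loss plateaus
            break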
Example #2
    def fit(self, data_inst, validate_data=None):

        # init binning obj
        self.aggregator = HomoBoostArbiterAggregator()
        self.binning_obj = HomoFeatureBinningServer()

        self.federated_binning()
        # initializing
        self.feature_num = self.sync_feature_num()

        if self.task_type == consts.CLASSIFICATION:
            label_mapping = HomoLabelEncoderArbiter().label_alignment()
            LOGGER.info('label mapping is {}'.format(label_mapping))
            self.booster_dim = len(
                label_mapping) if len(label_mapping) > 2 else 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory(
                "diff", self.tol)

        # sync start round and end round
        self.sync_start_round_and_end_round()

        LOGGER.info('begin to fit a boosting tree')
        self.preprocess()
        for epoch_idx in range(self.start_round, self.boosting_round):

            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))

            for class_idx in range(self.booster_dim):
                model = self.fit_a_learner(epoch_idx, class_idx)

            global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx, ))
            self.history_loss.append(global_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

            self.callback_metric("loss", "train",
                                 [Metric(epoch_idx, global_loss)])

            if self.n_iter_no_change:
                should_stop = self.aggregator.broadcast_converge_status(
                    self.check_convergence, (global_loss, ),
                    suffix=(epoch_idx, ))
                LOGGER.debug('stop flag sent')
                if should_stop:
                    break

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"Best": min(self.history_loss)}))
        self.postprocess()
        self.callback_list.on_train_end()
        self.set_summary(self.generate_summary())
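Examples #2 and #3 share the same server-side loop: aggregate a global loss, run the convergence check, then broadcast the resulting flag so every party stops on the same round. A schematic sketch of what broadcast_converge_status does under that reading (the transfer call below is a placeholder, not the real FATE transfer-variable API):

    def broadcast_converge_status(self, check_func, args, suffix):
        # The arbiter decides convergence centrally...
        status = check_func(*args)
        # ...then ships the boolean to all parties so they stop on the same round
        self.converge_status_transfer.remote(status, suffix=suffix)  # placeholder transport
        return status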
Example #3
    def __init__(self, guest: PlainFTLGuestModel, model_param: FTLModelParam,
                 transfer_variable: HeteroFTLTransferVariable):
        super(HeteroFTLGuest, self).__init__()
        self.guest_model = guest
        self.model_param = model_param
        self.transfer_variable = transfer_variable
        self.max_iter = model_param.max_iter
        self.n_iter_ = 0
        # self.converge_func = DiffConverge(eps=model_param.eps)
        self.converge_func = converge_func_factory(early_stop='diff',
                                                   tol=model_param.eps)

    def _init_model(self, params):
        self.model_param = params
        self.alpha = params.alpha
        self.init_param_obj = params.init_param
        self.fit_intercept = self.init_param_obj.fit_intercept
        self.batch_size = params.batch_size
        self.max_iter = params.max_iter
        self.optimizer = optimizer_factory(params)
        self.converge_func = converge_func_factory(params.early_stop, params.tol)
        self.encrypted_calculator = None
        self.validation_freqs = params.validation_freqs

    def fit(self, data_inst, valid_inst=None):

        self.federated_binning()
        # initializing
        self.feature_num = self.sync_feature_num()
        self.tree_dim = 1

        if self.task_type == consts.CLASSIFICATION:
            label_mapping = self.label_alignment()
            LOGGER.debug('label mapping is {}'.format(label_mapping))
            self.tree_dim = len(label_mapping) if len(label_mapping) > 2 else 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory(
                "diff", self.tol)

        LOGGER.debug('begin to fit a boosting tree')
        for epoch_idx in range(self.num_trees):

            for t_idx in range(self.tree_dim):
                valid_feature = self.sample_valid_feature()
                self.send_valid_features(valid_feature, epoch_idx, t_idx)
                flow_id = self.generate_flowid(epoch_idx, t_idx)
                new_tree = HomoDecisionTreeArbiter(self.tree_param,
                                                   valid_feature=valid_feature,
                                                   epoch_idx=epoch_idx,
                                                   flow_id=flow_id,
                                                   tree_idx=t_idx)
                new_tree.fit()

            global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx, ))
            self.global_loss_history.append(global_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

            self.callback_metric("loss", "train",
                                 [Metric(epoch_idx, global_loss)])

            if self.n_iter_no_change:
                should_stop = self.aggregator.broadcast_converge_status(
                    self.check_convergence, (global_loss, ),
                    suffix=(epoch_idx, ))
                LOGGER.debug('stop flag sent')
                if should_stop:
                    break

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"Best": min(self.global_loss_history)}))

        LOGGER.debug('fitting homo decision tree done')
Example #6
    def test_abs_converge(self):
        loss = 50
        eps = 0.00001
        # converge_func = AbsConverge(eps=eps)
        converge_func = converge_func_factory(early_stop='abs', tol=eps)

        iter_num = 0
        while iter_num < 500:
            loss *= 0.5
            converge_flag = converge_func.is_converge(loss)
            if converge_flag:
                break
            iter_num += 1
        self.assertTrue(math.fabs(loss) <= eps)
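The factory internals are not shown on this page; a hypothetical AbsConverge consistent with this test (converged once the absolute loss drops to eps or below) would look like:

    import math

    class AbsConverge:  # illustrative reconstruction, not the library source
        def __init__(self, eps):
            self.eps = eps

        def is_converge(self, loss):
            # absolute criterion: |loss| <= eps
            return math.fabs(loss) <= self.eps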
Example #7
    def test_diff_converge(self):
        loss = 50
        eps = 0.00001
        # converge_func = DiffConverge(eps=eps)
        converge_func = converge_func_factory(early_stop='diff', tol=eps)
        iter_num = 0
        pre_loss = loss
        while iter_num < 500:
            loss *= 0.5
            converge_flag = converge_func.is_converge(loss)
            if converge_flag:
                break
            iter_num += 1
            pre_loss = loss
        self.assertTrue(math.fabs(pre_loss - loss) <= eps)
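Similarly, a hypothetical DiffConverge matching this test would track the previous loss and compare successive values (again an illustrative reconstruction, not the library source):

    import math

    class DiffConverge:
        def __init__(self, eps):
            self.eps = eps
            self.pre_loss = None

        def is_converge(self, loss):
            # the first call only records a baseline
            if self.pre_loss is None:
                self.pre_loss = loss
                return False
            # relative criterion: |prev_loss - loss| <= eps
            converged = math.fabs(self.pre_loss - loss) <= self.eps
            self.pre_loss = loss
            return converged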
Example #8
    def _init_model(self, params):
        self.model_param = params
        self.alpha = params.alpha
        self.init_param_obj = params.init_param
        # self.fit_intercept = self.init_param_obj.fit_intercept
        self.batch_size = params.batch_size
        self.max_iter = params.max_iter
        self.optimizer = optimizer_factory(params)
        self.converge_func = converge_func_factory(params.early_stop, params.tol)
        self.encrypted_calculator = None
        self.validation_freqs = params.validation_freqs
        self.validation_strategy = None
        self.early_stopping_rounds = params.early_stopping_rounds
        self.metrics = params.metrics
        self.use_first_metric_only = params.use_first_metric_only
Example #9
def server_init_model(self, param):
    self.aggregate_iteration_num = 0
    self.aggregator = secure_mean_aggregator.Server(
        self.transfer_variable.secure_aggregator_trans_var)
    self.loss_scatter = loss_scatter.Server(
        self.transfer_variable.loss_scatter_trans_var)
    self.has_converged = has_converged.Server(
        self.transfer_variable.has_converged_trans_var)

    self._summary = dict(loss_history=[], is_converged=False)

    self.param = param
    self.enable_secure_aggregate = param.secure_aggregate
    self.max_aggregate_iteration_num = param.max_iter
    early_stop = self.model_param.early_stop
    self.converge_func = converge_func_factory(early_stop.converge_func,
                                               early_stop.eps).is_converge
    self.loss_consumed = early_stop.converge_func != "weight_diff"
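The loss_consumed flag at the end encodes which signal feeds the checker: 'diff' and 'abs' consume the scalar loss, while 'weight_diff' compares model weights. Later in the aggregation loop this would branch roughly as follows (a sketch; the variable names are illustrative, not from this snippet):

    if self.loss_consumed:
        is_converged = self.converge_func(total_loss)        # 'diff' / 'abs': scalar loss
    else:
        is_converged = self.converge_func(aggregated_model)  # 'weight_diff': model weights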
Example #10
    def _init_model(self, params):
        self.model_param = params
        self.alpha = params.alpha
        self.init_param_obj = params.init_param
        # self.fit_intercept = self.init_param_obj.fit_intercept
        self.batch_size = params.batch_size

        if hasattr(params, "shuffle"):
            self.shuffle = params.shuffle
        if hasattr(params, "masked_rate"):
            self.masked_rate = params.masked_rate
        if hasattr(params, "batch_strategy"):
            self.batch_strategy = params.batch_strategy

        self.max_iter = params.max_iter
        self.optimizer = optimizer_factory(params)
        self.converge_func = converge_func_factory(params.early_stop, params.tol)
        self.validation_freqs = params.callback_param.validation_freqs
        self.validation_strategy = None
        self.early_stopping_rounds = params.callback_param.early_stopping_rounds
        self.metrics = params.callback_param.metrics
        self.use_first_metric_only = params.callback_param.use_first_metric_only
Example #11
    def _init_model(self, param):
        super(HomoNNArbiter, self)._init_model(param)
        early_stop = self.model_param.early_stop
        self.converge_func = converge_func_factory(early_stop.converge_func,
                                                   early_stop.eps).is_converge
        self.loss_consumed = early_stop.converge_func != "weight_diff"
Example #12
    def _init_model(self, hetero_nn_param):
        super(HeteroNNGuest, self)._init_model(hetero_nn_param)

        self.task_type = hetero_nn_param.task_type
        self.converge_func = converge_func_factory(self.early_stop, self.tol)

    def check_convergence(self, loss):
        LOGGER.info("check convergence")
        if self.convergence is None:
            # fall back to the strategy configured in _init_model above
            self.convergence = converge_func_factory(self.early_stop, self.tol)

        return self.convergence.is_converge(loss)