Example #1
0
    def fit(self, data_inst, validate_data=None):
        """Drive the boosting rounds and record the aggregated loss history.

        Steps, in order: federated binning, feature-count sync, label
        alignment (classification tasks only, which also sets
        ``self.booster_dim``), then up to ``self.boosting_round`` epochs.
        Every epoch fits one booster per dimension, aggregates the global
        loss, appends it to ``self.history_loss`` and reports it through
        ``self.callback_metric``.  When ``self.n_iter_no_change`` is truthy
        the convergence status is broadcast after each epoch and the loop
        ends as soon as that status is truthy.  Finally the smallest
        recorded loss is published as metric meta and the summary is set.

        Neither ``data_inst`` nor ``validate_data`` is read in this method.
        """
        self.federated_binning()
        # initializing
        self.feature_num = self.sync_feature_num()

        if self.task_type == consts.CLASSIFICATION:
            aligned_labels = HomoLabelEncoderArbiter().label_alignment()
            LOGGER.info('label mapping is {}'.format(aligned_labels))
            n_labels = len(aligned_labels)
            # More than two labels: one booster dimension per label.
            # Otherwise a single dimension is used.
            if n_labels > 2:
                self.booster_dim = n_labels
            else:
                self.booster_dim = 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory(
                "diff", self.tol)

        LOGGER.info('begin to fit a boosting tree')
        for round_idx in range(self.boosting_round):
            LOGGER.info('cur epoch idx is {}'.format(round_idx))
            # The same per-epoch suffix tags both aggregator calls below.
            round_suffix = (round_idx,)

            for dim_idx in range(self.booster_dim):
                self.fit_a_booster(round_idx, dim_idx)

            epoch_loss = self.aggregator.aggregate_loss(suffix=round_suffix)
            self.history_loss.append(epoch_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(epoch_loss))

            loss_point = Metric(round_idx, epoch_loss)
            self.callback_metric("loss", "train", [loss_point])

            # Without early stopping there is nothing to broadcast.
            if not self.n_iter_no_change:
                continue

            converged = self.aggregator.broadcast_converge_status(
                self.check_convergence, (epoch_loss,), suffix=round_suffix)
            LOGGER.debug('stop flag sent')
            if converged:
                break

        best_loss = min(self.history_loss)
        loss_meta = MetricMeta(name="train",
                               metric_type="LOSS",
                               extra_metas={"Best": best_loss})
        self.callback_meta("loss", "train", loss_meta)

        summary = self.generate_summary()
        self.set_summary(summary)
Example #2
0
def server_fit(self, data_inst):
    """Run the server-side aggregation loop until convergence or the cap.

    On a cold start (``self.component_properties.is_warm_start`` is falsy)
    the label mapping is aligned via ``HomoLabelEncoderArbiter`` and logged;
    otherwise ``self.callback_warm_start_init_iter`` is called with the next
    iteration number.  Each loop pass increments
    ``self.aggregate_iteration_num``, stores the aggregator's weighted-mean
    model on ``self.model``, sends it back through the aggregator, and stops
    early when ``server_is_converged(self)`` is truthy.  The summary is set
    from ``self._summary`` once the loop is over.

    Args:
        self: Object holding the aggregation state read and updated here.
        data_inst: Not read by this function.
    """
    if not self.component_properties.is_warm_start:
        label_mapping = HomoLabelEncoderArbiter().label_alignment()
        LOGGER.info(f"label mapping: {label_mapping}")
    else:
        self.callback_warm_start_init_iter(self.aggregate_iteration_num + 1)

    while self.aggregate_iteration_num + 1 < self.max_aggregate_iteration_num:
        # update iteration num
        self.aggregate_iteration_num += 1

        self.callback_list.on_epoch_begin(self.aggregate_iteration_num)
        self.model = self.aggregator.weighted_mean_model(suffix=_suffix(self))
        self.aggregator.send_aggregated_model(model=self.model,
                                              suffix=_suffix(self))
        self.callback_list.on_epoch_end(self.aggregate_iteration_num)
        if server_is_converged(self):
            LOGGER.info(f"early stop at iter {self.aggregate_iteration_num}")
            break
    else:
        # ``while ... else``: runs only when the loop condition became false
        # without hitting ``break``, i.e. no early stop happened.
        # ``warning`` is used because ``warn`` is a deprecated alias on
        # standard-library loggers (assumes LOGGER is one -- confirm).
        LOGGER.warning(
            f"reach max iter: {self.aggregate_iteration_num}, not converged")
    self.set_summary(self._summary)
Example #3
0
 def dataset_align():
     """Run label alignment and log the resulting mapping.

     The mapping is only logged; nothing is returned.
     """
     LOGGER.info("start label alignment")
     encoder = HomoLabelEncoderArbiter()
     label_mapping = encoder.label_alignment()
     LOGGER.info(f"label aligned, mapping: {label_mapping}")
Example #4
0
 def _server_check_data(self):
     """Run label alignment; its result is discarded."""
     encoder = HomoLabelEncoderArbiter()
     encoder.label_alignment()