Ejemplo n.º 1
0
 def callback(self):
     """Report intersection statistics and attach the intersection meta.

     Logs intersect/unmatched counts and rates under
     ``self.metric_namespace`` / ``self.metric_name``, then records the
     configured intersect and join methods as extra metas of that metric.
     """
     param = self.model_param
     extra = {
         "intersect_method": param.intersect_method,
         "join_method": param.join_method,
     }
     stats = [
         Metric("intersect_count", self.intersect_num),
         Metric("intersect_rate", self.intersect_rate),
         Metric("unmatched_count", self.unmatched_num),
         Metric("unmatched_rate", self.unmatched_rate),
     ]
     self.callback_metric(metric_name=self.metric_name,
                          metric_namespace=self.metric_namespace,
                          metric_data=stats)
     meta = MetricMeta(name=self.metric_name,
                       metric_type=self.metric_type,
                       extra_metas=extra)
     self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                  metric_name=self.metric_name,
                                  metric_meta=meta)
Ejemplo n.º 2
0
    def callback_info(self):
        """Publish the sample-weighting configuration as metric meta.

        When ``self.class_weight_dict`` is non-empty, its keys are stringified
        for both the ``class_weight`` mapping and the sorted ``classes`` list;
        otherwise both are reported as ``None``. A placeholder metric with
        value 0 is logged, then the meta is attached to it.
        """
        weights = self.class_weight_dict
        if not weights:
            class_weight, classes = None, None
        else:
            class_weight = {str(label): w for label, w in weights.items()}
            classes = sorted(str(label) for label in weights)

        extra = {
            "weight_mode": self.weight_mode,
            "class_weight": class_weight,
            "classes": classes,
            "sample_weight_name": self.sample_weight_name,
        }
        metric_meta = MetricMeta(name='train',
                                 metric_type=self.metric_type,
                                 extra_metas=extra)

        name, namespace = self.metric_name, self.metric_namespace
        self.callback_metric(metric_name=name,
                             metric_namespace=namespace,
                             metric_data=[Metric(name, 0)])
        self.tracker.set_metric_meta(metric_namespace=namespace,
                                     metric_name=name,
                                     metric_meta=metric_meta)
 def _display_result(self, block_num=None):
     """Report intersection coverage and block number.

     Args:
         block_num: block count to report. When ``None``,
             ``self.block_num`` is reported instead.
     """
     # The two original branches were identical except for the source of
     # the block count, so a single reporting path is used.
     reported_blocks = self.block_num if block_num is None else block_num
     self.callback_metric(metric_name=self.metric_name,
                          metric_namespace=self.metric_namespace,
                          metric_data=[Metric("Coverage", self.coverage),
                                       Metric("Block number", reported_blocks)])
     self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                  metric_name=self.metric_name,
                                  metric_meta=MetricMeta(self.metric_name, metric_type="INTERSECTION"))
Ejemplo n.º 4
0
    def fit(self, data_inst, validate_data=None):
        """Run the arbiter side of homogeneous federated boosting.

        Sets up the aggregator and binning server, synchronises the feature
        count, the label alignment (classification only) and the round range,
        then, for every boosting round, drives one learner per booster
        dimension, aggregates the global loss and, when ``n_iter_no_change``
        is set, broadcasts the convergence status.

        Args:
            data_inst: not referenced in this body; kept for interface
                compatibility.
            validate_data: not referenced in this body.
        """
        # init binning obj
        self.aggregator = HomoBoostArbiterAggregator()
        self.binning_obj = HomoFeatureBinningServer()

        self.federated_binning()
        # initializing
        self.feature_num = self.sync_feature_num()

        if self.task_type == consts.CLASSIFICATION:
            label_mapping = HomoLabelEncoderArbiter().label_alignment()
            LOGGER.info('label mapping is {}'.format(label_mapping))
            # two classes (binary) are handled with a single booster dimension
            self.booster_dim = len(
                label_mapping) if len(label_mapping) > 2 else 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory(
                "diff", self.tol)

        # sync start round and end round
        self.sync_start_round_and_end_round()

        LOGGER.info('begin to fit a boosting tree')
        self.preprocess()
        for epoch_idx in range(self.start_round, self.boosting_round):

            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))

            for class_idx in range(self.booster_dim):
                # the returned learner is not used on the arbiter side
                self.fit_a_learner(epoch_idx, class_idx)

            global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx, ))
            self.history_loss.append(global_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

            self.callback_metric("loss", "train",
                                 [Metric(epoch_idx, global_loss)])

            if self.n_iter_no_change:
                should_stop = self.aggregator.broadcast_converge_status(
                    self.check_convergence, (global_loss, ),
                    suffix=(epoch_idx, ))
                LOGGER.debug('stop flag sent')
                if should_stop:
                    break

        # min() on an empty history raises ValueError (e.g. when the round
        # range is empty), so "Best" is only reported when a loss exists.
        extra_metas = {"Best": min(self.history_loss)} if self.history_loss else {}
        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas=extra_metas))
        self.postprocess()
        self.callback_list.on_train_end()
        self.set_summary(self.generate_summary())
Ejemplo n.º 5
0
    def record_step_best(self, step_best, host_mask, guest_mask, data_instances, model):
        """Publish the best model of the current stepwise step as a metric.

        Gathers the host/guest feature masks, their anonymous feature names,
        and the loss and score of ``self.models_trained[step_best]``. On
        non-arbiter roles it also gathers the data header and the fitted
        intercept/weight/header. Everything is reported as the
        ``extra_metas`` of a ``stepwise_<n_step>`` metric via ``model``.

        Args:
            step_best: key into ``self.models_trained`` for the chosen model.
            host_mask: host feature mask; must support ``tolist``, ``sum``
                and ``len``.
            guest_mask: guest feature mask; same requirements.
            data_instances: table whose ``schema['header']`` is read on
                non-arbiter roles.
            model: model used for anonymous names, parameter lookup and
                metric reporting.
        """
        def anonymize(role, mask):
            return [anonymous_generator.generate_anonymous(fid=fid, role=role, model=model)
                    for fid in range(len(mask))]

        metas = {
            "host_mask": host_mask.tolist(),
            "guest_mask": guest_mask.tolist(),
            "score_name": self.score_name,
            "number_in": int(sum(host_mask) + sum(guest_mask)),
            "direction": self.direction,
            "n_count": int(self.n_count),
        }

        host_anonym = anonymize('host', host_mask)
        guest_anonym = anonymize('guest', guest_mask)
        metas["host_features_anonym"] = host_anonym
        metas["guest_features_anonym"] = guest_anonym

        best_info = self.models_trained[step_best]
        loss = best_info.get_loss()
        ic_val = best_info.get_score()
        metas.update(loss=loss, current_ic_val=ic_val, fit_intercept=model.fit_intercept)

        model_dict = self._get_model(best_info.get_key())

        if self.role == consts.ARBITER:
            self.update_summary_arbiter(model, loss, ic_val)
        else:
            header = data_instances.schema.get('header')
            metas["all_features"] = header
            metas["to_enter"] = self.get_to_enter(host_mask, guest_mask, header)
            first_model = list(model_dict.get('model').values())[0]
            param_dict = MessageToDict(first_model.get(model.model_param_name))
            # the very first forward step reports the stored intercept instead
            if self.n_step == 0 and self.direction == "forward":
                intercept = self.intercept
            else:
                intercept = param_dict.get("intercept", None)
            metas["intercept"] = intercept
            metas["weight"] = param_dict.get("weight", {})
            metas["header"] = param_dict.get("header", [])
            self.update_summary_client(model, host_mask, guest_mask, header, host_anonym, guest_anonym)

        metric_name = f"stepwise_{self.n_step}"
        step_metric = [Metric(metric_name, float(self.n_step))]
        model.callback_metric(metric_name=metric_name, metric_namespace=self.metric_namespace,
                              metric_data=step_metric)
        step_meta = MetricMeta(name=metric_name, metric_type=self.metric_type, extra_metas=metas)
        model.tracker.set_metric_meta(metric_name=metric_name, metric_namespace=self.metric_namespace,
                                      metric_meta=step_meta)
        LOGGER.info(f"metric_name: {metric_name}, metas: {metas}")
Ejemplo n.º 6
0
 def callback(self, metas):
     """Log a placeholder metric (value 0) and attach ``metas`` to it.

     Args:
         metas: passed unchanged as ``extra_metas`` of the metric meta.
     """
     name, namespace = self.metric_name, self.metric_namespace
     self.callback_metric(metric_name=name,
                          metric_namespace=namespace,
                          metric_data=[Metric(name, 0)])
     meta = MetricMeta(name=name,
                       metric_type=self.metric_type,
                       extra_metas=metas)
     self.tracker.set_metric_meta(metric_name=name,
                                  metric_namespace=namespace,
                                  metric_meta=meta)
Ejemplo n.º 7
0
 def __save_curve_data(self, x_axis_list, y_axis_list, metric_name,
                       metric_namespace):
     """Log a curve as points sorted by x through the tracker.

     Float x values and every y value are rounded with ``np.round`` to
     ``self.round_num`` decimals; non-float x values are kept unchanged.
     Points are stably sorted by x before logging.
     """
     digits = self.round_num
     # NOTE: numpy float32 values are not ``float`` instances, so such x
     # values are left unrounded by this check.
     curve = [
         (np.round(x, digits) if isinstance(x, float) else x,
          np.round(y_axis_list[idx], digits))
         for idx, x in enumerate(x_axis_list)
     ]
     curve.sort(key=lambda pair: pair[0])
     self.tracker.log_metric_data(metric_namespace, metric_name,
                                  [Metric(x, y) for x, y in curve])
Ejemplo n.º 8
0
    def callback_loss(self, iter_num, loss):
        """Record the training loss of one iteration.

        Registers a LOSS meta (unit "iters") for metric 'loss' in namespace
        'train', then logs the point ``(iter_num, loss)``.
        """
        loss_meta = MetricMeta(name='train', metric_type="LOSS",
                               extra_metas={"unit_name": "iters"})
        self.callback_meta(metric_name='loss', metric_namespace='train',
                           metric_meta=loss_meta)
        point = Metric(iter_num, loss)
        self.callback_metric(metric_name='loss', metric_namespace='train',
                             metric_data=[point])
Ejemplo n.º 9
0
    def callback_info(self):
        """Report ``self.label_encoder`` as meta of a placeholder metric.

        Logs a metric with value 0 under ``self.metric_name`` /
        ``self.metric_namespace`` and attaches a 'train' meta carrying the
        label encoder.
        """
        encoder_meta = MetricMeta(name='train',
                                  metric_type=self.metric_type,
                                  extra_metas={"label_encoder": self.label_encoder})
        name, namespace = self.metric_name, self.metric_namespace
        self.callback_metric(metric_name=name,
                             metric_namespace=namespace,
                             metric_data=[Metric(name, 0)])
        self.tracker.set_metric_meta(metric_namespace=namespace,
                                     metric_name=name,
                                     metric_meta=encoder_meta)
Ejemplo n.º 10
0
    def callback_dbi(self, iter_num, dbi):
        """Record the 'DBI' metric value of one training iteration.

        Registers a DBI meta (unit "iters") for metric 'DBI' in namespace
        'train', then logs the point ``(iter_num, dbi)``.
        """
        where = {"metric_name": 'DBI', "metric_namespace": 'train'}
        dbi_meta = MetricMeta(name='train', metric_type="DBI",
                              extra_metas={"unit_name": "iters"})
        self.callback_meta(metric_meta=dbi_meta, **where)
        self.callback_metric(metric_data=[Metric(iter_num, dbi)], **where)
Ejemplo n.º 11
0
    def __save_single_value(self, result, metric_name, metric_namespace,
                            eval_name):
        """Log one rounded evaluation value together with its summary meta.

        The meta type is 'CLUSTERING_EVALUATION_SUMMARY' when ``eval_name``
        is listed in ``consts.ALL_CLUSTER_METRICS`` and
        'EVALUATION_SUMMARY' otherwise. ``result`` is rounded to
        ``self.round_num`` decimals.
        """
        is_cluster_metric = eval_name in consts.ALL_CLUSTER_METRICS
        metric_type = ('CLUSTERING_EVALUATION_SUMMARY' if is_cluster_metric
                       else 'EVALUATION_SUMMARY')

        rounded = np.round(result, self.round_num)
        self.tracker.log_metric_data(metric_namespace, metric_name,
                                     [Metric(eval_name, rounded)])
        summary_meta = MetricMeta(name=metric_name, metric_type=metric_type)
        self.tracker.set_metric_meta(metric_namespace, metric_name, summary_meta)
Ejemplo n.º 12
0
    def callback_loss(self, iter_num, loss):
        """Record the training loss of one iteration.

        Registers a LOSS meta (unit "iters") for metric "loss" in namespace
        "train", then logs the point ``(iter_num, loss)``.
        """
        target = {"metric_name": "loss", "metric_namespace": "train"}
        # noinspection PyTypeChecker
        loss_meta = MetricMeta(name="train", metric_type="LOSS",
                               extra_metas={"unit_name": "iters"})
        self.callback_meta(metric_meta=loss_meta, **target)
        self.callback_metric(metric_data=[Metric(iter_num, loss)], **target)
Ejemplo n.º 13
0
def server_callback_loss(self, iter_num, loss):
    """Record one iteration's training loss and append it to the history.

    Registers a LOSS meta (unit "iters") for metric "loss" in namespace
    "train", logs the point ``(iter_num, loss)``, then appends ``loss`` to
    ``self._summary["loss_history"]``.
    """
    # noinspection PyTypeChecker
    loss_meta = MetricMeta(name="train", metric_type="LOSS",
                           extra_metas={"unit_name": "iters"})
    self.callback_meta(metric_name="loss", metric_namespace="train",
                       metric_meta=loss_meta)
    point = Metric(iter_num, loss)
    self.callback_metric(metric_name="loss", metric_namespace="train",
                         metric_data=[point])
    self._summary["loss_history"].append(loss)
Ejemplo n.º 14
0
    def __sample(self, data_inst, sample_ids=None):
        """
        Random sample method: each row's chance of being drawn is governed by
        ``self.fraction``. Down sampling and up sampling are supported:
            downsample: ``fraction`` must be a number in [0, 1]; rows are
                drawn without replacement (at least one row).
            upsample: ``fraction`` must be greater than 0; rows are drawn
                with replacement and re-keyed 0..n-1.

        Parameters
        ----------
        data_inst : Table
            The input data

        sample_ids : None or list
            If None, ids are drawn using the instance's parameters.
            Otherwise this is a transform run and the given ids are used
            to build the output.

        Returns
        -------
        new_data_inst: Table
            the sampled data, same format as the input

        sample_ids: list, returned only if the ``sample_ids`` argument is None

        Raises
        ------
        ValueError
            If ``fraction`` is out of range for the method, or the method is
            not "downsample"/"upsample".
        """
        LOGGER.info("start to run random sampling")

        return_sample_ids = sample_ids is None

        if self.method == "downsample":
            if return_sample_ids:
                # validate before collecting ids so bad input fails fast
                if self.fraction < 0 or self.fraction > 1:
                    raise ValueError(
                        "sample fraction should be a numeric number between 0 and 1 inclusive"
                    )
                # drop values before collecting so only keys are gathered
                idset = [
                    key for key, _ in data_inst.mapValues(
                        lambda val: None).collect()
                ]
                sample_num = max(1, int(self.fraction * len(idset)))
                sample_ids = resample(idset,
                                      replace=False,
                                      n_samples=sample_num,
                                      random_state=self.random_state)

            sample_dtable = session.parallelize(zip(sample_ids,
                                                    range(len(sample_ids))),
                                                include_key=True,
                                                partition=data_inst.partitions)
            new_data_inst = data_inst.join(sample_dtable, lambda v1, v2: v1)

        elif self.method == "upsample":
            if return_sample_ids and self.fraction <= 0:
                raise ValueError(
                    "sample fraction should be a numeric number larger than 0"
                )

            data_set = list(data_inst.collect())
            idset = [key for key, _ in data_set]
            id_maps = {key: idx for idx, key in enumerate(idset)}

            if return_sample_ids:
                sample_num = int(self.fraction * len(idset))
                sample_ids = resample(idset,
                                      replace=True,
                                      n_samples=sample_num,
                                      random_state=self.random_state)

            # re-key drawn rows 0..n-1 so repeated ids become distinct rows
            new_data = [(new_key, data_set[id_maps[sample_id]][1])
                        for new_key, sample_id in enumerate(sample_ids)]
            new_data_inst = session.parallelize(new_data,
                                                include_key=True,
                                                partition=data_inst.partitions)

        else:
            raise ValueError("random sampler not support method {} yet".format(
                self.method))

        callback(self.tracker,
                 "random", [Metric("count", new_data_inst.count())],
                 summary_dict=self._summary_buf)

        if return_sample_ids:
            return new_data_inst, sample_ids
        return new_data_inst
Ejemplo n.º 15
0
    def __sample(self, data_inst, sample_ids=None):
        """
        Stratified sample method: each row's chance of being drawn is
        governed by the fraction configured for its label. The input must be
        a Table whose values are instances with a ``label``; ``self.fractions``
        holds one (label, ratio) pair per distinct label.
            downsample: every ratio must be in [0, 1]; drawn without
                replacement.
            upsample: every ratio must be greater than 0; drawn with
                replacement and re-keyed 0..n-1.
        Every non-empty label bucket yields at least one row.

        Parameters
        ----------
        data_inst : Table
            The input data

        sample_ids : None or list
            If None, ids are drawn using the instance's parameters.
            Otherwise this is a transform run and the given ids are used
            to build the output.

        Returns
        -------
        new_data_inst: Table
            the sampled data, same format as the input

        sample_ids: list, returned only if the ``sample_ids`` argument is None

        Raises
        ------
        ValueError
            If a fraction is out of range, a label has no configured rate,
            or the method is not "downsample"/"upsample".
        """
        LOGGER.info("start to run stratified sampling")
        return_sample_ids = sample_ids is None

        if self.method == "downsample":
            if return_sample_ids:
                for _, fraction in self.fractions:
                    if fraction < 0 or fraction > 1:
                        raise ValueError(
                            "sample fractions should be a numeric number between 0 and 1 inclusive"
                        )
                idset = self._group_ids_by_label(data_inst.collect())
                sample_ids = self._resample_and_report(idset, replace=False)

            sample_dtable = session.parallelize(zip(sample_ids,
                                                    range(len(sample_ids))),
                                                include_key=True,
                                                partition=data_inst.partitions)
            new_data_inst = data_inst.join(sample_dtable, lambda v1, v2: v1)

        elif self.method == "upsample":
            if return_sample_ids:
                for _, fraction in self.fractions:
                    if fraction <= 0:
                        raise ValueError(
                            "sample fractions should be a numeric number greater than 0"
                        )

            data_set = list(data_inst.collect())
            id_maps = {key: idx for idx, (key, _) in enumerate(data_set)}

            if return_sample_ids:
                idset = self._group_ids_by_label(data_set)
                sample_ids = self._resample_and_report(idset, replace=True)

            # re-key drawn rows 0..n-1 so repeated ids become distinct rows
            new_data = [(new_key, data_set[id_maps[sample_id]][1])
                        for new_key, sample_id in enumerate(sample_ids)]
            new_data_inst = session.parallelize(new_data,
                                                include_key=True,
                                                partition=data_inst.partitions)

        else:
            raise ValueError(
                "Stratified sampler not support method {} yet".format(
                    self.method))

        if return_sample_ids:
            return new_data_inst, sample_ids
        return new_data_inst

    def _group_ids_by_label(self, rows):
        """Bucket row keys by label index taken from ``self.label_mapping``.

        Raises ValueError when a row's label has no entry in the mapping.
        """
        idset = [[] for _ in range(len(self.fractions))]
        for key, inst in rows:
            label = inst.label
            if label not in self.label_mapping:
                raise ValueError(
                    "label not specify sample rate! check it please")
            idset[self.label_mapping[label]].append(key)
        return idset

    def _resample_and_report(self, idset, replace):
        """Resample each label bucket by its fraction, shuffle, and report counts.

        Returns the shuffled list of drawn ids and sends per-label original
        and sampled counts through ``callback``.
        """
        sample_ids = []
        callback_sample_metrics = []
        callback_original_metrics = []

        for i, ids in enumerate(idset):
            label_name = self.labels[i]
            callback_original_metrics.append(Metric(label_name, len(ids)))

            if ids:
                sample_num = max(1, int(self.fractions[i][1] * len(ids)))
                _sample_ids = resample(ids,
                                       replace=replace,
                                       n_samples=sample_num,
                                       random_state=self.random_state)
                sample_ids.extend(_sample_ids)
                callback_sample_metrics.append(
                    Metric(label_name, len(_sample_ids)))
            else:
                callback_sample_metrics.append(Metric(label_name, 0))

        # NOTE(review): the shuffle uses the global ``random`` RNG rather than
        # self.random_state, so the id order is not reproducible from
        # random_state alone.
        random.shuffle(sample_ids)

        callback(self.tracker, "stratified", callback_sample_metrics,
                 callback_original_metrics, self._summary_buf)
        return sample_ids
Ejemplo n.º 16
0
    def fit(self, data):
        """Union several tables into a single table.

        Tables are combined in ``data`` iteration order. Empty tables are
        counted and skipped. When ``self.keep_duplicate`` is set, every key
        is suffixed with its table name so rows from different tables do
        not collide. Per-table and total counts are reported as metrics.

        Args:
            data: dict mapping table name to table; must hold more than one
                entry.

        Returns:
            The combined table.

        Raises:
            ValueError: if ``data`` is not a dict with more than one entry,
                or DataInstance and non-DataInstance tables are mixed. The
                ``check_*`` helpers may raise as well.
        """
        if not isinstance(data, dict) or len(data) <= 1:
            raise ValueError(
                "Union module must receive more than one table as input.")

        def _suffix_keys(suffix):
            # Bind the table name now. An inline lambda in the loop would
            # close over the loop variable ``key`` itself, so a map that is
            # evaluated lazily (after the loop has advanced) would tag rows
            # with the wrong table name.
            return lambda k, v: (f"{k}_{suffix}", v)

        empty_count = 0
        combined_table = None
        combined_schema = None
        metrics = []

        for (key, local_table) in data.items():
            LOGGER.debug("table to combine name: {}".format(key))
            num_data = local_table.count()
            LOGGER.debug("table count: {}".format(num_data))
            metrics.append(Metric(key, num_data))
            self.add_summary(key, num_data)

            if num_data == 0:
                LOGGER.warning("Table {} is empty.".format(key))
                if combined_table is None:
                    combined_table = local_table
                    combined_schema = local_table.schema
                empty_count += 1
                continue

            local_is_data_instance = self.check_is_data_instance(local_table)
            if self.is_data_instance is None or combined_table is None:
                self.is_data_instance = local_is_data_instance
            LOGGER.debug(f"self.is_data_instance is {self.is_data_instance}, "
                         f"local_is_data_instance is {local_is_data_instance}")
            if self.is_data_instance != local_is_data_instance:
                raise ValueError(
                    f"Cannot combine DataInstance and non-DataInstance object. Union aborted."
                )

            if self.is_data_instance:
                self.is_empty_feature = data_overview.is_empty_feature(
                    local_table)
                if self.is_empty_feature:
                    LOGGER.warning("Table {} has empty feature.".format(key))
                else:
                    self.check_schema_content(local_table.schema)

            if combined_table is None or combined_table.count() == 0:
                # first non-empty table to combine
                combined_table = local_table
                combined_schema = local_table.schema
                if self.keep_duplicate:
                    combined_table = combined_table.map(_suffix_keys(key))
                    combined_table.schema = combined_schema
            else:
                self.check_id(local_table, combined_table)
                self.check_label_name(local_table, combined_table)
                self.check_header(local_table, combined_table)
                if self.keep_duplicate:
                    local_table = local_table.map(_suffix_keys(key))

                combined_table = combined_table.union(local_table,
                                                      self._keep_first)

                combined_table.schema = combined_schema

            # only check feature length if not empty
            if self.is_data_instance and not self.is_empty_feature:
                self.feature_count = len(combined_schema.get("header"))
                # NOTE(review): the mapValues result is discarded; this check
                # presumably relies on the backend evaluating it eagerly.
                combined_table.mapValues(self.check_feature_length)

        if combined_table is None:
            LOGGER.warning(
                "All tables provided are empty or have empty features.")
            first_table = list(data.values())[0]
            combined_table = first_table.join(first_table)
        num_data = combined_table.count()
        metrics.append(Metric("Total", num_data))
        self.add_summary("Total", num_data)
        LOGGER.info(f"Result total data entry count: {num_data}")

        self.callback_metric(metric_name=self.metric_name,
                             metric_namespace=self.metric_namespace,
                             metric_data=metrics)
        self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                     metric_name=self.metric_name,
                                     metric_meta=MetricMeta(
                                         name=self.metric_name,
                                         metric_type=self.metric_type))

        LOGGER.info(
            "Union operation finished. Total {} empty tables encountered.".
            format(empty_count))

        return combined_table
Ejemplo n.º 17
0
    def fit(self, data_inst, validate_data=None):
        """Train the hetero NN model epoch by epoch.

        Builds a fresh model, or warm-starts from ``self.history_iter_epoch``,
        and prepares batches from ``data_inst``. It then trains until
        ``self.epochs`` is reached, early stopping is signalled through
        ``self.callback_variables.stop_training``, or the loss converges.
        Each epoch's convergence flag is sent to the host (idx 0). This
        appears to be the label-holding side (it trains on ``data_y``), but
        that is inferred and should be confirmed against the class.

        Args:
            data_inst: training table, turned into batches by
                ``prepare_batch_data``.
            validate_data: forwarded to ``callback_list.on_train_begin``.
        """
        self.callback_list.on_train_begin(data_inst, validate_data)

        # collect data from table to form data loader
        if not self.component_properties.is_warm_start:
            self._build_model()
            cur_epoch = 0
        else:
            # warm start: resume at the epoch after the last recorded one
            self.model.warm_start()
            self.callback_warm_start_init_iter(self.history_iter_epoch)
            cur_epoch = self.history_iter_epoch + 1

        self.prepare_batch_data(self.batch_generator, data_inst)
        if not self.input_shape:
            self.model.set_empty()

        self._set_loss_callback_info()
        while cur_epoch < self.epochs:
            self.iter_epoch = cur_epoch
            LOGGER.debug("cur epoch is {}".format(cur_epoch))
            self.callback_list.on_epoch_begin(cur_epoch)
            epoch_loss = 0

            for batch_idx in range(len(self.data_x)):
                # hetero NN model
                batch_loss = self.model.train(self.data_x[batch_idx], self.data_y[batch_idx], cur_epoch, batch_idx)

                epoch_loss += batch_loss

            # mean loss over batches; raises ZeroDivisionError if data_x is empty
            epoch_loss /= len(self.data_x)

            LOGGER.debug("epoch {}' loss is {}".format(cur_epoch, epoch_loss))

            self.callback_metric("loss",
                                 "train",
                                 [Metric(cur_epoch, epoch_loss)])

            self.history_loss.append(epoch_loss)

            self.callback_list.on_epoch_end(cur_epoch)
            if self.callback_variables.stop_training:
                # NOTE(review): this break happens before is_converge is sent
                # for the epoch; presumably the host learns of the stop some
                # other way (e.g. via the callback list) — confirm.
                LOGGER.debug('early stopping triggered')
                break

            if self.hetero_nn_param.selector_param.method:
                # when use selective bp, loss converge will be disabled
                is_converge = False
            else:
                is_converge = self.converge_func.is_converge(epoch_loss)
            self._summary_buf["is_converged"] = is_converge
            # tell the host (idx 0) whether training stops; suffix ties the
            # message to this epoch
            self.transfer_variable.is_converge.remote(is_converge,
                                                      role=consts.HOST,
                                                      idx=0,
                                                      suffix=(cur_epoch,))

            if is_converge:
                LOGGER.debug("Training process is converged in epoch {}".format(cur_epoch))
                break

            cur_epoch += 1

        # both breaks leave cur_epoch below self.epochs, so this fires only
        # when neither early stopping nor convergence ended the loop
        if cur_epoch == self.epochs:
            LOGGER.debug("Training process reach max training epochs {} and not converged".format(self.epochs))

        self.callback_list.on_train_end()
        # if self.validation_strategy and self.validation_strategy.has_saved_best_model():
        #     self.load_model(self.validation_strategy.cur_best_model)

        self.set_summary(self._get_model_summary())
Ejemplo n.º 18
0
    def fit(self, data_inst, validate_data=None):
        """Train the FTL model on the guest side.

        Prepares the training data, initializes the local neural network and
        then runs ``self.epochs`` training epochs. In each epoch the guest
        exchanges intermediate components (via ``exchange_components``),
        performs ``self.local_round`` local gradient updates, records the loss
        computed in the first local round and, when ``n_iter_no_change`` is
        ``True``, synchronizes an early-stop flag via ``sync_stop_flag``.

        Args:
            data_inst: training data; must expose ``partitions``.
            validate_data: optional validation data, forwarded to the
                callback list.
        """

        LOGGER.debug('in training, partitions is {}'.format(
            data_inst.partitions))
        LOGGER.info('start to fit a ftl model, '
                    'run mode is {}, '
                    'communication efficient mode is {}'.format(
                        self.mode, self.comm_eff))

        self.check_host_number()

        data_loader, self.x_shape, self.data_num, self.overlap_num = self.prepare_data(
            self.init_intersect_obj(), data_inst, guest_side=True)
        self.input_dim = self.x_shape[0]

        # cache data_loader for faster validation
        self.cache_dataloader[self.get_dataset_key(data_inst)] = data_loader

        self.partitions = data_inst.partitions
        LOGGER.debug('self partitions is {}'.format(self.partitions))

        self.initialize_nn(input_shape=self.x_shape)
        # second dimension of the network output is used as the feature dim
        self.feat_dim = self.nn._model.output_shape[1]
        self.constant_k = 1 / self.feat_dim
        self.callback_list.on_train_begin(train_data=data_inst,
                                          validate_data=validate_data)

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"unit_name": "iters"}))

        # compute intermediate result of first epoch
        self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(
            data_loader)

        for epoch_idx in range(self.epochs):

            LOGGER.debug('fitting epoch {}'.format(epoch_idx))

            self.callback_list.on_epoch_begin(epoch_idx)

            # components are exchanged once per epoch and reused by every
            # local round below
            host_components = self.exchange_components(self.send_components,
                                                       epoch_idx=epoch_idx)

            # stays None if self.local_round is 0 (the inner loop never runs)
            loss = None

            for local_round_idx in range(self.local_round):

                if self.comm_eff:
                    LOGGER.debug(
                        'running local iter {}'.format(local_round_idx))

                grads = self.compute_backward_gradients(
                    host_components,
                    data_loader,
                    epoch_idx=epoch_idx,
                    local_round=local_round_idx)
                self.update_nn_weights(grads,
                                       data_loader,
                                       epoch_idx,
                                       decay=self.comm_eff)

                # the epoch loss is only evaluated on the first local round
                if local_round_idx == 0:
                    loss = self.compute_loss(
                        host_components, epoch_idx,
                        len(data_loader.get_overlap_indexes()))

                # refresh phi / overlap_ua before the next local round
                # (presumably to reflect the weights just updated)
                if local_round_idx + 1 != self.local_round:
                    self.phi, self.overlap_ua = self.compute_phi_and_overlap_ua(
                        data_loader)

            self.callback_metric("loss", "train", [Metric(epoch_idx, loss)])
            self.history_loss.append(loss)

            # updating variables for next epochs
            if epoch_idx + 1 == self.epochs:
                # only need to update phi in last epochs
                self.phi, _ = self.compute_phi_and_overlap_ua(data_loader)
            else:
                # compute phi, phi_product, overlap_ua etc. for next epoch
                self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(
                    data_loader)

            self.callback_list.on_epoch_end(epoch_idx)

            # check n_iter_no_change; identity check, so only the bool True
            # enables the convergence test and stop-flag sync
            if self.n_iter_no_change is True:
                if self.check_convergence(loss):
                    self.sync_stop_flag(epoch_idx, stop_flag=True)
                    break
                else:
                    self.sync_stop_flag(epoch_idx, stop_flag=False)

            LOGGER.debug('fitting epoch {} done, loss is {}'.format(
                epoch_idx, loss))

        self.callback_list.on_train_end()

        # history_loss may be empty (e.g. epochs == 0) or hold None entries
        # (local_round == 0); min() would raise on either, so the best loss is
        # only reported when at least one real loss value was recorded.
        recorded_losses = [epoch_loss for epoch_loss in self.history_loss
                           if epoch_loss is not None]
        if recorded_losses:
            self.callback_meta(
                "loss", "train",
                MetricMeta(name="train",
                           metric_type="LOSS",
                           extra_metas={"Best": min(recorded_losses)}))

        self.set_summary(self.generate_summary())
        LOGGER.debug('fitting ftl model done')
Ejemplo n.º 19
0
    def fit(self, data_inst, validate_data=None):
        """Fit the hetero boosting model on the guest side.

        Bins the input data and extracts labels. On a fresh start it builds the
        feature-name mapping, checks labels, creates the loss function and the
        initial scores. On a warm start it delegates to ``prepare_warm_start``.
        It then iterates over ``range(self.start_round, self.boosting_round)``.
        In each round it fits ``self.booster_dim`` learners (one per
        ``class_idx``), updates ``self.y_hat``, and records the training loss.
        It stops early on convergence (when ``n_iter_no_change`` is set) or when
        ``self.stop_training`` is truthy.

        Args:
            data_inst: training data.
            validate_data: optional validation data, forwarded to the
                callback list.
        """

        LOGGER.info('begin to fit a hetero boosting model, model is {}'.format(
            self.model_name))

        # NOTE(review): prepare_warm_start presumably moves start_round
        # forward on a warm start — confirm.
        self.start_round = 0

        self.on_training = True

        self.data_inst = data_inst

        self.data_bin, self.bin_split_points, self.bin_sparse_points = self.prepare_data(
            data_inst)

        self.y = self.get_label(self.data_bin)

        if not self.is_warm_start:
            self.feature_name_fid_mapping = self.gen_feature_fid_mapping(
                data_inst.schema)
            self.classes_, self.num_classes, self.booster_dim = self.check_label(
            )
            self.loss = self.get_loss_function()
            self.y_hat, self.init_score = self.get_init_score(
                self.y, self.num_classes)
        else:
            # only the freshly checked classes are needed on a warm start
            classes_, _, _ = self.check_label()
            self.prepare_warm_start(data_inst, classes_)

        LOGGER.info('class index is {}'.format(self.classes_))

        self.sync_booster_dim()

        self.generate_encrypter()

        self.callback_list.on_train_begin(data_inst, validate_data)

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"unit_name": "iters"}))

        self.preprocess()

        for epoch_idx in range(self.start_round, self.boosting_round):

            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))

            self.callback_list.on_epoch_begin(epoch_idx)

            for class_idx in range(self.booster_dim):

                # fit a booster
                model = self.fit_a_learner(epoch_idx, class_idx)

                booster_meta, booster_param = model.get_model()

                # only keep learners that actually produced a model
                if booster_meta is not None and booster_param is not None:
                    self.booster_meta = booster_meta
                    self.boosting_model_list.append(booster_param)

                # update predict score
                cur_sample_weights = model.get_sample_weights()
                self.y_hat = self.get_new_predict_score(self.y_hat,
                                                        cur_sample_weights,
                                                        dim=class_idx)

            # compute loss
            loss = self.compute_loss(self.y_hat, self.y)
            self.history_loss.append(loss)
            LOGGER.info("round {} loss is {}".format(epoch_idx, loss))
            self.callback_metric("loss", "train", [Metric(epoch_idx, loss)])

            # hand the already-computed train scores to validation so they
            # do not have to be recomputed
            validation_strategy = self.callback_list.get_validation_strategy()
            if validation_strategy:
                validation_strategy.set_precomputed_train_scores(
                    self.score_to_predict_result(data_inst, self.y_hat))

            self.callback_list.on_epoch_end(epoch_idx)

            should_stop = False
            if self.n_iter_no_change and self.check_convergence(loss):
                should_stop = True
                self.is_converged = True
            self.sync_stop_flag(self.is_converged, epoch_idx)
            # stop_training is presumably set by a callback (e.g. early
            # stopping) — confirm against the callback list implementation
            if self.stop_training or should_stop:
                break

        self.postprocess()
        self.callback_list.on_train_end()
        # history_loss can be empty when range(start_round, boosting_round)
        # is empty and no loss was recorded before; min() would raise
        # ValueError, so the best loss is only reported when one exists.
        if self.history_loss:
            self.callback_meta(
                "loss", "train",
                MetricMeta(name="train",
                           metric_type="LOSS",
                           extra_metas={"Best": min(self.history_loss)}))
        # get summary
        self.set_summary(self.generate_summary())