def predict_f_value(self, data_inst):
        """Recompute ``self.F`` by applying every stored tree to ``data_inst``.

        ``self.F`` is reset so that each record maps to a zero vector of
        length ``self.tree_dim``. ``self.trees_`` is walked as a sequence of
        groups; for every tree model in a group a fresh
        ``HeteroDecisionTreeGuest`` is built, the model is installed, and the
        tree's prediction is handed to ``update_f_value`` together with the
        tree's position inside its group.

        Args:
            data_inst: object exposing ``mapValues``; it is also passed
                unchanged to each tree's ``predict``.
        """
        LOGGER.info("predict tree f value")
        dim = self.tree_dim
        self.F = data_inst.mapValues(lambda _: np.zeros(dim))

        for group_idx, tree_group in enumerate(self.trees_):
            for slot, tree_model in enumerate(tree_group):
                guest_tree = HeteroDecisionTreeGuest(self.tree_param)
                guest_tree.set_tree_model(tree_model)
                # The flow id is derived from (group index, position in group).
                guest_tree.set_flowid(self.generate_flowid(group_idx, slot))

                self.update_f_value(guest_tree.predict(data_inst), slot)
    def predict_f_value(self, data_inst):
        """Recompute ``self.F`` from ``self.init_score`` plus all stored trees.

        ``self.trees_`` is treated as a flat list laid out round by round,
        with ``self.tree_dim`` consecutive entries per round. Only complete
        rounds (``len(self.trees_) // self.tree_dim``) are evaluated. Each
        tree is loaded with ``self.tree_meta`` and its prediction is passed
        to ``update_f_value`` with the tree's offset inside its round.

        Args:
            data_inst: object exposing ``mapValues``; it is also passed
                unchanged to each tree's ``predict``.
        """
        LOGGER.info("predict tree f value, there are {} trees".format(len(self.trees_)))
        dim = self.tree_dim
        base_score = self.init_score
        self.F = data_inst.mapValues(lambda _: base_score)

        complete_rounds = len(self.trees_) // dim
        for flat_idx in range(complete_rounds * dim):
            # flat_idx == round_idx * dim + slot
            round_idx, slot = divmod(flat_idx, dim)

            guest_tree = HeteroDecisionTreeGuest(self.tree_param)
            guest_tree.load_model(self.tree_meta, self.trees_[flat_idx])
            guest_tree.set_flowid(self.generate_flowid(round_idx, slot))

            self.update_f_value(new_f=guest_tree.predict(data_inst), tidx=slot)
示例#3
0
    def fit(self, data_inst, validate_data=None):
        """Train the guest-side boosting model round by round.

        After the preparation calls (feature mapping, alignment, binning,
        label/score initialisation, encrypter creation, ``sync_tree_dim``),
        the method runs up to ``self.num_trees`` rounds. Each round computes
        gradients/hessians once, fits ``self.tree_dim`` guest trees, appends
        every tree's parameters to ``self.trees_``, and updates the running
        scores and feature importance. The round loss is appended to
        ``self.history_loss`` and reported through ``callback_metric``.

        Convergence is only checked when ``self.n_iter_no_change is True``;
        in that case the stop decision is sent via ``sync_stop_flag`` every
        round and the loop ends as soon as ``check_convergence`` is truthy.

        Args:
            data_inst: training data; ``data_inst.schema`` is read before the
                data is replaced by the result of ``data_alignment``.
            validate_data: optional data forwarded to
                ``init_validation_strategy``.
        """
        LOGGER.info("begin to train secureboosting guest model")

        # Preparation; the order of these calls is kept exactly as before.
        self.gen_feature_fid_mapping(data_inst.schema)
        data_inst = self.data_alignment(data_inst)
        self.convert_feature_to_bin(data_inst)
        self.set_y()
        self.update_f_value()
        self.generate_encrypter()

        self.sync_tree_dim()

        iter_meta = MetricMeta(name="train",
                               metric_type="LOSS",
                               extra_metas={"unit_name": "iters"})
        self.callback_meta("loss", "train", iter_meta)

        validator = self.init_validation_strategy(data_inst, validate_data)

        for boost_round in range(self.num_trees):
            self.compute_grad_and_hess()

            for dim_idx in range(self.tree_dim):
                guest_tree = HeteroDecisionTreeGuest(self.tree_param)

                guest_tree.set_inputinfo(self.data_bin,
                                         self.get_grad_and_hess(dim_idx),
                                         self.bin_split_points,
                                         self.bin_sparse_points)

                guest_tree.set_valid_features(self.sample_valid_features())
                guest_tree.set_encrypter(self.encrypter)
                guest_tree.set_encrypted_mode_calculator(
                    self.encrypted_calculator)
                guest_tree.set_flowid(
                    self.generate_flowid(boost_round, dim_idx))
                guest_tree.set_host_party_idlist(
                    self.component_properties.host_party_idlist)
                guest_tree.set_runtime_idx(
                    self.component_properties.local_partyid)

                guest_tree.fit()

                fitted_meta, fitted_param = guest_tree.get_model()
                self.trees_.append(fitted_param)
                # Only the first tree's meta is retained.
                if self.tree_meta is None:
                    self.tree_meta = fitted_meta
                self.update_f_value(new_f=guest_tree.predict_weights,
                                    tidx=dim_idx)
                self.update_feature_importance(
                    guest_tree.get_feature_importance())

            loss = self.compute_loss()
            self.history_loss.append(loss)
            LOGGER.info("round {} loss is {}".format(boost_round, loss))

            LOGGER.debug("type of loss is {}".format(type(loss).__name__))
            self.callback_metric("loss", "train", [Metric(boost_round, loss)])

            if validator:
                validator.validate(self, boost_round)

            # Stop flag is only synchronised when the check is enabled.
            if self.n_iter_no_change is not True:
                continue
            converged = bool(self.check_convergence(loss))
            self.sync_stop_flag(converged, boost_round)
            if converged:
                break

        best_loss = min(self.history_loss)
        LOGGER.debug("history loss is {}".format(best_loss))
        best_meta = MetricMeta(name="train",
                               metric_type="LOSS",
                               extra_metas={"Best": best_loss})
        self.callback_meta("loss", "train", best_meta)

        LOGGER.info("end to train secureboosting guest model")
    def fit(self, data_inst):
        """Train the guest-side boosting model round by round.

        After alignment, binning, label/score initialisation, encrypter
        creation and ``sync_tree_dim``, the method runs up to
        ``self.num_trees`` rounds. Each round computes gradients/hessians
        once, fits ``self.tree_dim`` guest trees, appends every tree's
        parameters to the flat list ``self.trees_``, and updates the running
        scores. The round loss is appended to ``self.history_loss``.

        Convergence is only checked when ``self.n_iter_no_change is True``;
        in that case the stop decision is sent via ``sync_stop_flag`` every
        round and the loop ends as soon as ``check_convergence`` is truthy.

        Args:
            data_inst: training data; replaced by the result of
                ``data_alignment`` before binning.
        """
        LOGGER.info("begin to train secureboosting guest model")

        # Preparation; the order of these calls is kept exactly as before.
        data_inst = self.data_alignment(data_inst)
        self.convert_feature_to_bin(data_inst)
        self.set_y()
        self.update_f_value()
        self.generate_encrypter()

        self.sync_tree_dim()

        for boost_round in range(self.num_trees):
            self.compute_grad_and_hess()

            for dim_idx in range(self.tree_dim):
                guest_tree = HeteroDecisionTreeGuest(self.tree_param)

                guest_tree.set_inputinfo(self.data_bin,
                                         self.get_grad_and_hess(dim_idx),
                                         self.bin_split_points,
                                         self.bin_sparse_points)

                guest_tree.set_valid_features(self.sample_valid_features())
                guest_tree.set_encrypter(self.encrypter)
                guest_tree.set_flowid(
                    self.generate_flowid(boost_round, dim_idx))

                guest_tree.fit()

                fitted_meta, fitted_param = guest_tree.get_model()
                self.trees_.append(fitted_param)
                # Only the first tree's meta is retained.
                if self.tree_meta is None:
                    self.tree_meta = fitted_meta
                self.update_f_value(new_f=guest_tree.predict_weights,
                                    tidx=dim_idx)

            loss = self.compute_loss()
            self.history_loss.append(loss)
            LOGGER.info("round {} loss is {}".format(boost_round, loss))

            # Stop flag is only synchronised when the check is enabled.
            if self.n_iter_no_change is not True:
                continue
            converged = bool(self.check_convergence(loss))
            self.sync_stop_flag(converged, boost_round)
            if converged:
                break

        LOGGER.info("end to train secureboosting guest model")
    def predict_f_value(self, data_inst, cache_dataset_key):
        """Compute ``self.predict_F``, resuming from cached rounds if any.

        ``self.predict_data_cache`` is asked for the last cached round under
        ``cache_dataset_key``. A value of ``-1`` is handled as "nothing
        cached": scores start from ``self.init_score``. Otherwise the cached
        scores at ``min(rounds - 1, last cached round)`` are reused. The
        first round still to be evaluated is announced through
        ``sync_predict_start_round``; every remaining round applies its
        ``self.tree_dim`` trees and then stores ``self.predict_F`` back into
        the cache.

        Args:
            data_inst: object exposing ``mapValues``; it is also passed
                unchanged to each tree's ``predict``.
            cache_dataset_key: key used for every lookup and insert on
                ``self.predict_data_cache``.
        """
        LOGGER.info("predict tree f value, there are {} trees".format(
            len(self.trees_)))
        base_score = self.init_score

        cached_round = self.predict_data_cache.predict_data_last_round(
            cache_dataset_key)
        total_rounds = len(self.trees_) // self.tree_dim

        if cached_round != -1:
            LOGGER.debug("hit cache, cached round is {}".format(cached_round))
            if cached_round >= total_rounds - 1:
                LOGGER.debug(
                    "predict data cached, rounds is {}, total cached round is {}"
                    .format(total_rounds, cached_round))

            # Never read past the last round this model actually has.
            resume_from = min(total_rounds - 1, cached_round)
            self.predict_F = self.predict_data_cache.predict_data_at(
                cache_dataset_key, resume_from)
        else:
            self.predict_F = data_inst.mapValues(lambda _: base_score)

        first_round = cached_round + 1
        self.sync_predict_start_round(first_round)

        for round_idx in range(first_round, total_rounds):
            for dim_idx in range(self.tree_dim):
                guest_tree = HeteroDecisionTreeGuest(self.tree_param)
                guest_tree.load_model(
                    self.tree_meta,
                    self.trees_[round_idx * self.tree_dim + dim_idx])
                guest_tree.set_flowid(self.generate_flowid(round_idx, dim_idx))
                guest_tree.set_runtime_idx(
                    self.component_properties.local_partyid)
                guest_tree.set_host_party_idlist(
                    self.component_properties.host_party_idlist)

                self.update_f_value(new_f=guest_tree.predict(data_inst),
                                    tidx=dim_idx,
                                    mode="predict")

            # Cache the scores once per completed round.
            self.predict_data_cache.add_data(cache_dataset_key, self.predict_F)