示例#1
0
    def get_model_param(self):
        """Build the model-parameter message for this boosting model.

        Copies the trained trees, scores, losses and class labels onto a new
        ``BoostingTreeModelParam`` and appends one ``FeatureImportanceInfo``
        per entry of ``self.get_feature_importance()``, ordered by importance
        (highest first).

        Returns:
            tuple: ``(param_name, model_param)`` where ``param_name`` is
            ``"HeteroSecureBoostingTreeGuestParam"`` and ``model_param`` is the
            populated ``BoostingTreeModelParam``.
        """
        model_param = BoostingTreeModelParam()

        # Materialise the trees exactly once so ``tree_num`` and the stored
        # ``trees_`` always agree. Counting a throw-away copy and then
        # iterating the attribute again would silently store nothing if
        # ``self.trees_`` were a one-shot iterable.
        trees = list(self.trees_)
        model_param.tree_num = len(trees)
        model_param.tree_dim = self.tree_dim
        model_param.trees_.extend(trees)
        model_param.init_score.extend(self.init_score)
        model_param.losses.extend(self.history_loss)
        # Class labels are serialised as strings.
        model_param.classes_.extend(map(str, self.classes_))
        model_param.num_classes = self.num_classes

        # Importance entries are keyed by (sitename, fid); sort by the
        # importance value, descending.
        feature_importances = sorted(self.get_feature_importance().items(),
                                     key=itemgetter(1),
                                     reverse=True)
        model_param.feature_importances.extend([
            FeatureImportanceInfo(sitename=sitename,
                                  fid=fid,
                                  importance=importance)
            for (sitename, fid), importance in feature_importances
        ])

        model_param.feature_name_fid_mapping.update(
            self.feature_name_fid_mapping)

        param_name = "HeteroSecureBoostingTreeGuestParam"

        return param_name, model_param
示例#2
0
    def get_model_param(self):
        """Build the model-parameter message for this homo boosting model.

        Copies the learnt trees, scores, local loss history and class labels
        onto a new ``BoostingTreeModelParam`` (``best_iteration`` is fixed to
        ``-1``) and appends one ``FeatureImportanceInfo`` per feature id,
        ordered by importance (highest first) and tagged with ``self.role``.

        Returns:
            tuple: ``(param_name, model_param)`` where ``param_name`` is
            ``"HomoSecureBoostingTreeGuestParam"`` and ``model_param`` is the
            populated ``BoostingTreeModelParam``.
        """
        model_param = BoostingTreeModelParam()

        # Materialise the trees exactly once so ``tree_num`` and the stored
        # ``trees_`` always agree, even if ``learnt_tree_param`` were a
        # one-shot iterable.
        trees = list(self.learnt_tree_param)
        model_param.tree_num = len(trees)
        model_param.tree_dim = self.tree_dim
        model_param.trees_.extend(trees)
        model_param.init_score.extend(self.init_score)
        model_param.losses.extend(self.local_loss_history)
        # Class labels are serialised as strings.
        model_param.classes_.extend(map(str, self.classes_))
        model_param.num_classes = self.num_classes
        model_param.best_iteration = -1

        # Importance entries are keyed by fid only; every entry is attributed
        # to this party's role.
        feature_importance = sorted(self.get_feature_importance().items(),
                                    key=itemgetter(1),
                                    reverse=True)
        model_param.feature_importances.extend([
            FeatureImportanceInfo(sitename=self.role,
                                  fid=fid,
                                  importance=importance)
            for fid, importance in feature_importance
        ])

        model_param.feature_name_fid_mapping.update(
            self.feature_name_fid_mapping)

        param_name = "HomoSecureBoostingTreeGuestParam"

        return param_name, model_param
示例#3
0
    def get_model_param(self):
        """Build the model-parameter message for this homo SecureBoost model.

        Copies the boosting trees, init scores and class labels onto a new
        ``BoostingTreeModelParam`` (``best_iteration`` fixed to ``-1``,
        ``model_name`` set to ``consts.HOMO_SBT``). It then appends one
        ``FeatureImportanceInfo`` per feature id, ordered by importance
        (highest first).

        Returns:
            tuple: ``(param_name, model_param)`` where ``param_name`` is
            ``"HomoSecureBoostingTreeGuestParam"`` and ``model_param`` is the
            populated ``BoostingTreeModelParam``.

        Raises:
            KeyError: if a feature id in ``self.feature_importance`` is missing
                from ``self.feature_name_fid_mapping``.
        """
        model_param = BoostingTreeModelParam()

        # Materialise the trees exactly once so ``tree_num`` and the stored
        # ``trees_`` always agree, even if ``boosting_model_list`` were a
        # one-shot iterable.
        trees = list(self.boosting_model_list)
        model_param.tree_num = len(trees)
        model_param.tree_dim = self.booster_dim
        model_param.trees_.extend(trees)
        model_param.init_score.extend(self.init_score)
        # Class labels are serialised as strings.
        model_param.classes_.extend(map(str, self.classes_))
        model_param.num_classes = self.num_classes
        model_param.best_iteration = -1
        model_param.model_name = consts.HOMO_SBT

        # Values are importance objects exposing ``importance``,
        # ``importance_2`` and ``main_type``; sorting relies on those objects
        # being orderable.
        feature_importance = sorted(self.feature_importance.items(),
                                    key=itemgetter(1),
                                    reverse=True)
        model_param.feature_importances.extend([
            FeatureImportanceInfo(
                fid=fid,
                fullname=self.feature_name_fid_mapping[fid],
                sitename=self.role,
                importance=importance.importance,
                importance2=importance.importance_2,
                main=importance.main_type)
            for fid, importance in feature_importance
        ])

        model_param.feature_name_fid_mapping.update(
            self.feature_name_fid_mapping)

        param_name = "HomoSecureBoostingTreeGuestParam"

        return param_name, model_param
示例#4
0
    def get_model_param(self):
        """Assemble the hetero SecureBoost guest model-parameter message.

        The trees, scores, losses, class labels, boosting-strategy model name,
        best iteration, per-feature importances (highest first) and the
        encoded tree plan are all written to a fresh
        ``BoostingTreeModelParam``.

        Returns:
            tuple: ``(param_name, model_param)`` with ``param_name`` equal to
            ``consts.HETERO_SBT_GUEST_MODEL + "Param"``.
        """
        trees = self.boosting_model_list

        param = BoostingTreeModelParam()
        param.tree_num = len(trees)
        param.tree_dim = self.booster_dim
        param.trees_.extend(trees)
        param.init_score.extend(self.init_score)
        param.losses.extend(self.history_loss)
        param.classes_.extend(str(label) for label in self.classes_)
        param.num_classes = self.num_classes

        # Pick the model name for the configured boosting strategy. A
        # strategy not listed here leaves ``model_name`` unset.
        strategy_to_name = (
            (consts.STD_TREE, consts.HETERO_SBT),
            (consts.LAYERED_TREE, consts.HETERO_FAST_SBT_LAYERED),
            (consts.MIX_TREE, consts.HETERO_FAST_SBT_MIX),
        )
        for strategy, name in strategy_to_name:
            if self.boosting_strategy == strategy:
                param.model_name = name
                break
        param.best_iteration = self.callback_variables.best_iteration

        def resolve_fullname(sitename, fid):
            # Sites containing consts.GUEST are looked up in the local
            # fid -> name mapping. Any other site is expected to look like
            # "role:party_id" and gets an anonymous name.
            if consts.GUEST in sitename:
                return self.feature_name_fid_mapping[fid]
            role_name, party_id = sitename.split(':')
            return generate_anonymous(fid=fid,
                                      party_id=party_id,
                                      role=role_name)

        ranked = sorted(self.feature_importances_.items(),
                        key=itemgetter(1),
                        reverse=True)
        param.feature_importances.extend([
            FeatureImportanceInfo(
                fullname=resolve_fullname(site, feature_id),
                sitename=site,  # sitename to distinguish sites
                fid=feature_id,
                importance=score.importance,
                importance2=score.importance_2,
                main=score.main_type)
            for (site, feature_id), score in ranked
        ])
        param.feature_name_fid_mapping.update(self.feature_name_fid_mapping)
        param.tree_plan.extend(plan.encode_plan(self.tree_plan))

        return consts.HETERO_SBT_GUEST_MODEL + "Param", param
示例#5
0
    def get_model_param(self):
        """Assemble the hetero SecureBoost guest model-parameter message.

        The trees, scores, losses and class labels, the model name
        ``consts.HETERO_SBT`` and the best iteration are written to a fresh
        ``BoostingTreeModelParam``. The best iteration is ``-1`` when no
        validation strategy is set. Per-feature importances are appended
        highest first.

        Returns:
            tuple: ``(param_name, model_param)`` with ``param_name`` equal to
            ``"HeteroSecureBoostingTreeGuestParam"``.
        """
        trees = self.boosting_model_list

        param = BoostingTreeModelParam()
        param.tree_num = len(trees)
        param.tree_dim = self.booster_dim
        param.trees_.extend(trees)
        param.init_score.extend(self.init_score)
        param.losses.extend(self.history_loss)
        param.classes_.extend(str(label) for label in self.classes_)
        param.num_classes = self.num_classes
        param.model_name = consts.HETERO_SBT

        validation = self.validation_strategy
        if validation is None:
            param.best_iteration = -1
        else:
            param.best_iteration = validation.best_iteration

        def resolve_fullname(sitename, fid):
            # Sites containing consts.GUEST are looked up in the local
            # fid -> name mapping. Any other site is expected to look like
            # "role:party_id" and gets an anonymous name.
            if consts.GUEST in sitename:
                return self.feature_name_fid_mapping[fid]
            role_name, party_id = sitename.split(':')
            return generate_anonymous(fid=fid,
                                      party_id=party_id,
                                      role=role_name)

        ranked = sorted(self.feature_importances_.items(),
                        key=itemgetter(1),
                        reverse=True)
        param.feature_importances.extend([
            FeatureImportanceInfo(fullname=resolve_fullname(site, feature_id),
                                  sitename=site,
                                  fid=feature_id,
                                  importance=score)
            for (site, feature_id), score in ranked
        ])

        param.feature_name_fid_mapping.update(self.feature_name_fid_mapping)

        return "HeteroSecureBoostingTreeGuestParam", param