예제 #1
0
    def test_use_first_metric_only(self):
        """Print the metric chosen by the first-metric-only rule for sample configs.

        The nested ``evaluate`` helper walks ``param.metrics`` in order and
        returns the first entry that appears in the single-value metric
        collection for ``param.eval_type``. The results for four
        ``EvaluateParam`` instances are printed, not asserted.
        """

        def evaluate(param, early_stopping_rounds, use_first_metric_only):
            """Return the first single-value metric in ``param.metrics``, or None.

            None is returned when early stopping or first-metric-only is
            disabled, when no metrics are configured, when the eval type is
            not one of the three handled below, or when none of the listed
            metrics is a single-value metric.
            """
            eval_type = param.eval_type
            metric_list = param.metrics

            if not (early_stopping_rounds and use_first_metric_only and metric_list):
                return None

            # Default to an empty collection so an unrecognised eval type
            # yields None instead of raising TypeError on ``metric in None``.
            single_metric_list = ()
            if eval_type == consts.BINARY:
                single_metric_list = consts.BINARY_SINGLE_VALUE_METRIC
            elif eval_type == consts.REGRESSION:
                single_metric_list = consts.REGRESSION_SINGLE_VALUE_METRICS
            elif eval_type == consts.MULTY:
                single_metric_list = consts.MULTI_SINGLE_VALUE_METRIC

            # The first match in the user-supplied order wins.
            return next(
                (metric for metric in metric_list if metric in single_metric_list),
                None,
            )

        param_0 = EvaluateParam(metrics=['roc', 'lift', 'ks', 'auc', 'gain'], eval_type='binary')
        param_1 = EvaluateParam(metrics=['acc', 'precision', 'auc'], eval_type='binary')
        param_2 = EvaluateParam(metrics=['acc', 'precision', 'gain', 'recall', 'lift'], eval_type='binary')
        param_3 = EvaluateParam(metrics=['acc', 'precision', 'gain', 'auc', 'recall'], eval_type='multi')

        for param in (param_0, param_1, param_2, param_3):
            print(evaluate(param, 10, True))
예제 #2
0
 def get_metrics_param(self):
     """Return the EvaluateParam matching this model's task type.

     Non-classification tasks are evaluated as "regression". Classification
     with exactly two classes is evaluated as "binary", using the second
     entry of ``self.classes_`` as the positive label. Any other class
     count is evaluated as "multi".
     """
     if self.task_type != consts.CLASSIFICATION:
         return EvaluateParam(eval_type="regression")

     if self.num_classes == 2:
         positive_label = self.classes_[1]
         return EvaluateParam(eval_type="binary", pos_label=positive_label)

     return EvaluateParam(eval_type="multi")
예제 #3
0
 def get_metrics_param(self):
     """Return the EvaluateParam for this task, carrying ``self.metrics``.

     Non-classification tasks use eval type "regression". Classification
     with ``num_label == 2`` uses "binary" with positive label 1. Other
     label counts use "multi".
     """
     if self.task_type != consts.CLASSIFICATION:
         eval_kwargs = {"eval_type": "regression"}
     elif self.num_label == 2:
         eval_kwargs = {"eval_type": "binary", "pos_label": 1}
     else:
         eval_kwargs = {"eval_type": "multi"}

     return EvaluateParam(metrics=self.metrics, **eval_kwargs)
예제 #4
0
파일: boosting.py 프로젝트: zeta1999/FATE
 def get_metrics_param(self):
     """
     Give the evaluation type for this model; called by the validation strategy.

     Returns an EvaluateParam carrying ``self.metrics`` whose eval type is
     "regression" for non-classification tasks, "binary" (positive label
     ``self.classes_[1]``) when there are exactly two classes, and "multi"
     otherwise.
     """
     is_classification = self.task_type == consts.CLASSIFICATION
     if not is_classification:
         return EvaluateParam(eval_type="regression", metrics=self.metrics)

     if self.num_classes != 2:
         return EvaluateParam(eval_type="multi", metrics=self.metrics)

     positive_label = self.classes_[1]
     return EvaluateParam(eval_type="binary",
                          pos_label=positive_label,
                          metrics=self.metrics)
예제 #5
0
 def __init__(self,
              method='train',
              train_input_table=None,
              train_input_namespace=None,
              model_table=None,
              model_namespace=None,
              predict_input_table=None,
              predict_input_namespace=None,
              predict_result_partition=1,
              predict_output_table=None,
              predict_output_namespace=None,
              evaluation_output_table=None,
              evaluation_output_namespace=None,
              data_input_table=None,
              data_input_namespace=None,
              intersect_data_output_table=None,
              intersect_data_output_namespace=None,
              dataio_param=DataIOParam(),
              predict_param=PredictParam(),
              evaluate_param=EvaluateParam(),
              do_cross_validation=False,
              work_mode=0,
              n_splits=5,
              need_intersect=True,
              need_sample=False,
              need_feature_selection=False,
              need_scale=False,
              one_vs_rest=False,
              need_one_hot=False):
     """Store the workflow configuration on the instance.

     Every argument is kept under an attribute of the same name. The three
     nested parameter objects (``dataio_param``, ``predict_param``,
     ``evaluate_param``) are deep-copied, so the instance never shares
     state with the objects passed in. This includes the default instances,
     which Python creates once when the function is defined.
     """
     self.method = method

     # Table/namespace pairs for each stage, assigned left to right.
     self.train_input_table, self.train_input_namespace = (
         train_input_table, train_input_namespace)
     self.model_table, self.model_namespace = model_table, model_namespace
     self.predict_input_table, self.predict_input_namespace = (
         predict_input_table, predict_input_namespace)
     self.predict_output_table, self.predict_output_namespace = (
         predict_output_table, predict_output_namespace)
     self.predict_result_partition = predict_result_partition
     self.evaluation_output_table, self.evaluation_output_namespace = (
         evaluation_output_table, evaluation_output_namespace)
     self.data_input_table, self.data_input_namespace = (
         data_input_table, data_input_namespace)
     self.intersect_data_output_table, self.intersect_data_output_namespace = (
         intersect_data_output_table, intersect_data_output_namespace)

     # Private copy of the data-IO settings.
     self.dataio_param = copy.deepcopy(dataio_param)

     # Cross-validation and runtime settings.
     self.do_cross_validation, self.n_splits = do_cross_validation, n_splits
     self.work_mode = work_mode

     # Private copies of the prediction and evaluation settings.
     self.predict_param = copy.deepcopy(predict_param)
     self.evaluate_param = copy.deepcopy(evaluate_param)

     # Optional pipeline-stage switches.
     self.need_intersect, self.need_sample = need_intersect, need_sample
     self.need_feature_selection, self.need_scale = (
         need_feature_selection, need_scale)
     self.need_one_hot, self.one_vs_rest = need_one_hot, one_vs_rest
예제 #6
0
 def __init__(self,
              n_splits=5,
              mode=consts.HETERO,
              role=consts.GUEST,
              shuffle=True,
              random_seed=1,
              evaluate_param=EvaluateParam(),
              need_cv=False):
     """Store the cross-validation settings on the instance.

     The parent initialiser runs first. Each argument is then kept under an
     attribute of the same name. ``evaluate_param`` is deep-copied so the
     instance does not share state with the object passed in or with the
     default instance created at definition time.
     """
     super(CrossValidationParam, self).__init__()

     self.n_splits = n_splits
     self.mode, self.role = mode, role
     self.shuffle, self.random_seed = shuffle, random_seed

     # Private copy of the evaluation settings.
     self.evaluate_param = copy.deepcopy(evaluate_param)
     self.need_cv = need_cv
예제 #7
0
 def get_metrics_param(self):
     """Return a regression EvaluateParam carrying ``self.metrics``."""
     regression_param = EvaluateParam(eval_type="regression",
                                      metrics=self.metrics)
     return regression_param
예제 #8
0
 def get_metrics_param(self):
     """Return a binary EvaluateParam carrying ``self.metrics``."""
     binary_param = EvaluateParam(eval_type="binary",
                                  metrics=self.metrics)
     return binary_param
예제 #9
0
 def get_metrics_param(self):
     """Return a binary EvaluateParam whose positive label is 1."""
     binary_param = EvaluateParam(eval_type="binary", pos_label=1)
     return binary_param
예제 #10
0
 def get_metrics_param(self):
     """Return an EvaluateParam of type 'multi' or "binary".

     The 'multi' type is used when ``self.need_one_vs_rest`` is truthy,
     "binary" otherwise.
     """
     eval_type = 'multi' if self.need_one_vs_rest else "binary"
     return EvaluateParam(eval_type=eval_type)
예제 #11
0
    def check(self):
        """Validate the workflow parameters required by the configured method.

        ``self.method`` is first normalised through
        ``check_and_change_lower``. The table, namespace and numeric
        attributes relevant to that method are then type-checked in a fixed
        order, and the nested parameter objects are checked by their own
        ``check`` functions.

        Returns:
            True when every check passes.

        Raises:
            ValueError: on the first attribute with the wrong type or an
                out-of-range value.
        """

        def require_type(type_name, *attr_names):
            # The exact type name is compared instead of using isinstance,
            # so subclasses are rejected (e.g. a bool is not an "int").
            for attr_name in attr_names:
                value = getattr(self, attr_name)
                if type(value).__name__ != type_name:
                    raise ValueError(
                        "workflow param's {} {} not supported, should be {} type"
                        .format(attr_name, value, type_name))

        descr = "workflow param's "
        supported_methods = [
            'train', 'predict', 'cross_validation', 'intersect', 'binning',
            'feature_select', 'one_vs_rest_train', "one_vs_rest_predict"
        ]
        self.method = self.check_and_change_lower(self.method,
                                                  supported_methods, descr)

        # Methods that use a model and produce predictions/evaluations.
        uses_model = self.method in ["train", "predict", "cross_validation"]

        if self.method in ['train', 'binning', 'feature_select']:
            require_type("str", 'train_input_table', 'train_input_namespace')

        if uses_model:
            require_type("str", 'model_table', 'model_namespace')

        if self.method == 'predict':
            require_type("str",
                         'predict_input_table', 'predict_input_namespace',
                         'predict_output_table', 'predict_output_namespace')

        if uses_model:
            require_type("int", 'predict_result_partition')
            require_type("str",
                         'evaluation_output_table',
                         'evaluation_output_namespace')

        if self.method == 'cross_validation':
            require_type("str", 'data_input_table', 'data_input_namespace')
            require_type("int", 'n_splits')
            if self.n_splits <= 0:
                raise ValueError(
                    "workflow param's n_splits must be greater or equal to 1")

        # Intersection outputs are optional: only validated when provided.
        for optional_attr in ('intersect_data_output_table',
                              'intersect_data_output_namespace'):
            if getattr(self, optional_attr) is not None:
                require_type("str", optional_attr)

        DataIOParam.check(self.dataio_param)

        require_type("int", 'work_mode')
        if self.work_mode not in [0, 1]:
            raise ValueError(
                "workflow param's work_mode must be 0 (represent to standalone mode) or 1 (represent to cluster mode)"
            )

        if uses_model:
            PredictParam.check(self.predict_param)
            EvaluateParam.check(self.evaluate_param)

        LOGGER.debug("Finish workerflow parameter check!")
        return True