示例#1
0
 def get_scale_param(self):
     """Return a ``ScaleParam`` fed with a fixed standard-scale configuration.

     A fresh ``ScaleParam`` and the hard-coded settings below are handed to
     ``ParamExtract.parse_param_from_config``. The return value of that call
     is not used; the object created here is what gets returned.
     """
     settings = dict(
         method="standard_scale",
         mode="normal",
         scale_col_indexes=[],
     )
     result = ScaleParam()
     ParamExtract().parse_param_from_config(result, settings)
     # Debug output showing the concrete type of the parameter object.
     print("scale_param:{}".format(type(result)))
     return result
示例#2
0
 def get_scale_param(self):
     """Return a ``ScaleParam`` fed with a fixed standard-scale configuration.

     The settings (method, mode, area and an empty ``scale_column_idx`` list)
     are passed together with a fresh ``ScaleParam`` to
     ``ParamExtract.parse_param_from_config``; the object created here is
     returned, not the value returned by that call.
     """
     settings = dict(
         method="standard_scale",
         mode="normal",
         area="all",
         scale_column_idx=[],
     )
     result = ScaleParam()
     ParamExtract().parse_param_from_config(result, settings)
     return result
示例#3
0
 def get_scale_param(self):
     """Return a ``ScaleParam`` fed with a fixed standard-scale configuration.

     The settings enable both ``with_mean`` and ``with_std`` and leave
     ``scale_col_indexes`` empty. They are passed with a fresh ``ScaleParam``
     to ``ParamExtract.parse_param_from_config``; the object created here is
     returned, not the value returned by that call.
     """
     settings = dict(
         method="standard_scale",
         mode="normal",
         scale_col_indexes=[],
         with_mean=True,
         with_std=True,
     )
     result = ScaleParam()
     ParamExtract().parse_param_from_config(result, settings)
     return result
示例#4
0
    def _init_runtime_parameters(self, component_parameters):
        """Parse the component parameters, set up the model and cache run flags.

        ``self.model_param`` is filled from ``component_parameters`` through
        ``ParamExtract.parse_param_from_config``, validated with ``check()``
        and passed to ``self._init_model``. Three flags are then read from the
        parsed object and stored on ``self``; each falls back to a default
        when the lookup raises ``AttributeError``:

        * ``need_cv`` from ``cv_param.need_cv`` (default ``False``)
        * ``need_run`` from ``need_run`` (default ``True``)
        * ``need_one_vs_rest`` from ``one_vs_rest_param.need_one_vs_rest``
          (default ``False``)
        """

        def _flag_or(read_flag, fallback):
            # A missing attribute anywhere along the lookup chain selects the fallback.
            try:
                return read_flag()
            except AttributeError:
                return fallback

        parsed = ParamExtract().parse_param_from_config(
            self.model_param, component_parameters)
        parsed.check()
        self._init_model(parsed)

        self.need_cv = _flag_or(lambda: parsed.cv_param.need_cv, False)
        self.need_run = _flag_or(lambda: parsed.need_run, True)
        self.need_one_vs_rest = _flag_or(
            lambda: parsed.one_vs_rest_param.need_one_vs_rest, False)

        LOGGER.debug("need_run: {}, need_cv: {}".format(
            self.need_run, self.need_cv))
 def test_directly_extract(self):
     """Check that values from ``self.config_json`` land on the param object.

     Uses ``assertEqual`` rather than ``assertTrue(a == b)`` so that a failure
     reports both the actual and the expected value.
     """
     param_obj = FeatureBinningParam()
     extractor = ParamExtract()
     param_obj = extractor.parse_param_from_config(param_obj,
                                                   self.config_json)
     self.assertEqual(param_obj.method, "quantile")
     # Nested parameter objects are expected to be populated as well.
     self.assertEqual(param_obj.transform_param.transform_type, 'bin_num')
示例#6
0
 def _init_runtime_parameters(self, component_parameters):
     """Parse and validate the component parameters, then set up the model.

     ``self.model_param`` is filled from ``component_parameters`` via
     ``ParamExtract.parse_param_from_config`` and validated with ``check()``.
     ``self.role`` is taken from the object returned by
     ``self.component_properties.parse_component_param``. Finally
     ``self._init_model`` is called with the parsed parameters.

     Returns the parsed parameter object.
     """
     parsed = ParamExtract().parse_param_from_config(
         self.model_param,
         component_parameters,
     )
     parsed.check()

     properties = self.component_properties.parse_component_param(
         component_parameters,
         parsed,
     )
     self.role = properties.role

     self._init_model(parsed)
     return parsed
示例#7
0
 def test_param_embedding(self):
     """Check that nested and top-level values are read from ``self.config_path``.

     Uses ``assertEqual`` rather than ``assertTrue(a == b)`` so that a failure
     reports both the actual and the expected value.
     """
     boosting_tree_param = BoostingTreeParam()
     extractor = ParamExtract()
     boosting_tree_param = extractor.parse_param_from_config(
         boosting_tree_param, self.config_path)
     # Value on the embedded tree parameter object.
     self.assertEqual(boosting_tree_param.tree_param.criterion_method,
                      "test_decisiontree")
     # Value on the outer boosting parameter object.
     self.assertEqual(boosting_tree_param.task_type, "test_boostingtree")
示例#8
0
 def test_param_embedding(self):
     """Check that the embedded tree parameter is read from ``self.config_dict``.

     Uses ``assertEqual`` rather than ``assertTrue(a == b)`` so that a failure
     reports both the actual and the expected value.
     """
     boosting_tree_param = HeteroSecureBoostParam()
     extractor = ParamExtract()
     boosting_tree_param = extractor.parse_param_from_config(
         boosting_tree_param, self.config_dict)
     # Debug output of the value under test.
     print("boosting_tree_param.tree_param.criterion_method {}".format(
         boosting_tree_param.tree_param.criterion_method))
     self.assertEqual(boosting_tree_param.tree_param.criterion_method,
                      "test_decisiontree")
示例#9
0
    def _check(self, Param, Checker):
        """Instantiate ``Param``, fill it from ``self.config_path`` and validate it.

        The filled object is first passed to ``Checker.check_param`` and then
        to ``self.all_checker.validate_restricted_param`` together with
        ``self.validation_json`` and ``self.param_classes``.
        """
        blank = Param()
        filled = ParamExtract.parse_param_from_config(blank, self.config_path)
        Checker.check_param(filled)

        self.all_checker.validate_restricted_param(
            filled,
            self.validation_json,
            self.param_classes,
        )
示例#10
0
 def _initialize_model(self, config_path):
     """Load the sampling parameters and build the guest-side sampler.

     A fresh ``NeighborsSamplingParam`` is filled from ``config_path`` and
     stored on ``self.neighbos_sampling_param`` (attribute spelling kept as
     is). ``self.neighbors_sampler`` is then created from that stored value.
     """
     blank = NeighborsSamplingParam()
     self.neighbos_sampling_param = ParamExtract.parse_param_from_config(
         blank,
         config_path,
     )
     self.neighbors_sampler = NeighborsSamplingGuest(
         self.neighbos_sampling_param,
     )
示例#11
0
 def test_undefine_variable_extract(self):
     """Check that the parsed param object has no ``test_variable`` attribute."""
     parsed = ParamExtract().parse_param_from_config(
         HeteroSecureBoostParam(),
         self.config_dict,
     )
     self.assertFalse(hasattr(parsed, "test_variable"))
示例#12
0
 def test_directly_extract(self):
     """Check that ``task_type`` is read from ``self.config_dict``.

     Uses ``assertEqual`` rather than ``assertTrue(a == b)`` so that a failure
     reports both the actual and the expected value.
     """
     boosting_tree_param = HeteroSecureBoostParam()
     extractor = ParamExtract()
     boosting_tree_param = extractor.parse_param_from_config(
         boosting_tree_param, self.config_dict)
     self.assertEqual(boosting_tree_param.task_type, "test_boostingtree")
示例#13
0
 def _check(self, Param, Checker):
     """Instantiate ``Param``, fill it from ``self.config_path`` and validate it.

     The filled object is handed to ``Checker.check_param``.
     """
     blank = Param()
     filled = ParamExtract.parse_param_from_config(blank, self.config_path)
     Checker.check_param(filled)
示例#14
0
 def _init_runtime_parameters(self, component_parameters):
     """Parse the component parameters and set up the model with them.

     ``self.model_param`` is filled from ``component_parameters`` via
     ``ParamExtract.parse_param_from_config``; the result is passed to
     ``self._init_model`` and returned.
     """
     parsed = ParamExtract().parse_param_from_config(
         self.model_param,
         component_parameters,
     )
     self._init_model(parsed)
     return parsed
示例#15
0
 def _initialize_intersect(self, config):
     """Fill a fresh ``IntersectParam`` from ``config`` and store it on ``self``."""
     blank = IntersectParam()
     self.intersect_param = ParamExtract.parse_param_from_config(blank, config)
示例#16
0
 def test_directly_extract(self):
     """Check that ``init_method`` is read from ``self.config_path``.

     Uses ``assertEqual`` rather than ``assertTrue(a == b)`` so that a failure
     reports both the actual and the expected value.
     """
     init_param = InitParam()
     extractor = ParamExtract()
     init_param = extractor.parse_param_from_config(init_param,
                                                    self.config_path)
     self.assertEqual(init_param.init_method, "test_init")
示例#17
0
class FTLWorkFlow(object):
    """Base workflow that wires FTL parameters to a model and drives train/predict.

    Subclasses must implement ``_do_initialize_model``. The data and
    train/predict hooks (``gen_data_instance``, ``train``, ``predict`` ...)
    are no-ops here.
    """

    def __init__(self):
        super().__init__()
        self.model = None
        self.job_id = None
        self.workflow_param = None
        self.param_extract = None
        # Declared up front so the ``_get_*`` accessors return None instead of
        # raising AttributeError before ``_initialize_model`` has run.
        self.ftl_data_param = None
        self.ftl_valid_data_param = None
        self.ftl_transfer_variable = None

    def _initialize(self, config):
        """Initialise the model parameters first, then the workflow parameters."""
        LOGGER.debug("Get in base workflow initialize")
        self._initialize_model(config)
        self._initialize_workflow_param(config)

    def _initialize_model(self, config):
        """Parse and check all FTL parameter objects, then build the model.

        Creates ``self.param_extract``, fills the model, local-model, data and
        validation-data parameter objects from ``config``, runs each class's
        ``check`` on them and delegates to ``_do_initialize_model``.
        """
        LOGGER.debug("@ initialize model")
        ftl_model_param = FTLModelParam()
        ftl_local_model_param = LocalModelParam()
        ftl_data_param = FTLDataParam()
        ftl_valid_data_param = FTLValidDataParam()

        self.param_extract = ParamExtract()
        ftl_model_param = self.param_extract.parse_param_from_config(ftl_model_param, config)
        ftl_local_model_param = self.param_extract.parse_param_from_config(ftl_local_model_param, config)
        self.ftl_data_param = self.param_extract.parse_param_from_config(ftl_data_param, config)
        self.ftl_valid_data_param = self.param_extract.parse_param_from_config(ftl_valid_data_param, config)
        self.ftl_transfer_variable = HeteroFTLTransferVariable()

        FTLModelParam.check(ftl_model_param)
        LocalModelParam.check(ftl_local_model_param)
        FTLDataParam.check(self.ftl_data_param)
        FTLValidDataParam.check(self.ftl_valid_data_param)

        self._do_initialize_model(ftl_model_param, ftl_local_model_param, self.ftl_data_param)

    def _initialize_workflow_param(self, config):
        """Parse ``WorkFlowParam`` from ``config``, store it and validate it.

        Requires ``self.param_extract`` to be set, which ``_initialize_model``
        does.
        """
        workflow_param = WorkFlowParam()
        self.workflow_param = self.param_extract.parse_param_from_config(workflow_param, config)
        # Validate the object that is actually kept and used by the workflow.
        self.workflow_param.check()

    def _get_transfer_variable(self):
        """Return the transfer variable, or None before model initialisation."""
        return self.ftl_transfer_variable

    def _get_data_model_param(self):
        """Return the data parameters, or None before model initialisation."""
        return self.ftl_data_param

    def _get_valid_data_model_param(self):
        """Return the validation-data parameters, or None before model initialisation."""
        return self.ftl_valid_data_param

    def _do_initialize_model(self, ftl_model_param: FTLModelParam, ftl_local_model_param: LocalModelParam,
                             ftl_data_param: FTLDataParam):
        """Build ``self.model`` from the parsed parameters; subclasses must override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("method init must be define")

    def save_eval_result(self, eval_data):
        """Persist ``eval_data`` to the evaluation output table named in the workflow params."""
        LOGGER.info("@ save evaluation result to table with namespace: {0} and name: {1}".format(
            self.workflow_param.evaluation_output_namespace, self.workflow_param.evaluation_output_table))
        session.parallelize([eval_data],
                            include_key=False,
                            name=self.workflow_param.evaluation_output_table,
                            namespace=self.workflow_param.evaluation_output_namespace,
                            error_if_exist=False,
                            persistent=True
                            )

    def save_predict_result(self, predict_result):
        """Save ``predict_result`` to the prediction output table named in the workflow params."""
        LOGGER.info("@ save prediction result to table with namespace: {0} and name: {1}".format(
            self.workflow_param.predict_output_namespace, self.workflow_param.predict_output_table))
        predict_result.save_as(self.workflow_param.predict_output_table, self.workflow_param.predict_output_namespace)

    def evaluate(self, eval_data):
        """Evaluate collected predictions with ``self.model.evaluate``.

        Each item yielded by ``eval_data.collect()`` is indexed as
        ``item[1][0]`` (label), ``item[1][1]`` (predicted probability) and
        ``item[1][2]`` (predicted label).

        Returns:
            The value returned by ``self.model.evaluate``, or None when
            ``eval_data`` is None.
        """
        if eval_data is None:
            LOGGER.info("not eval_data!")
            return None

        # Materialise once: ``collect()`` may return a one-shot iterator.
        records = [item[1] for item in eval_data.collect()]

        labels = np.array([record[0] for record in records])
        pred_prob = np.array([record[1] for record in records])
        pred_labels = np.array([record[2] for record in records])

        evaluation_result = self.model.evaluate(labels, pred_prob, pred_labels,
                                                evaluate_param=self.workflow_param.evaluate_param)
        return evaluation_result

    def _init_argument(self):
        """Hook called at the start of ``run``; no-op in the base class."""
        pass

    def gen_validation_data_instance(self, table, namespace):
        """Hook producing validation data for ``train``; no-op in the base class."""
        pass

    def gen_data_instance(self, table, namespace):
        """Hook producing input data for ``train``/``predict``; no-op in the base class."""
        pass

    def train(self, train_data_instance, validation_data=None):
        """Training hook; no-op in the base class."""
        pass

    def predict(self, data_instance):
        """Prediction hook; no-op in the base class."""
        pass

    def run(self):
        """Dispatch to ``train`` or ``predict`` according to ``workflow_param.method``.

        Raises:
            TypeError: if the method is neither ``"train"`` nor ``"predict"``.
        """
        self._init_argument()
        if self.workflow_param.method == "train":
            data_instance = self.gen_data_instance(self.workflow_param.train_input_table,
                                                   self.workflow_param.train_input_namespace)

            # Validation data for training is read from the predict input table.
            valid_instance = self.gen_validation_data_instance(self.workflow_param.predict_input_table,
                                                               self.workflow_param.predict_input_namespace)
            self.train(data_instance, valid_instance)

        elif self.workflow_param.method == "predict":
            data_instance = self.gen_data_instance(self.workflow_param.predict_input_table,
                                                   self.workflow_param.predict_input_namespace)
            self.predict(data_instance)
        else:
            raise TypeError("method %s is not support yet" % (self.workflow_param.method))