def _initialize_model(self, config):
    """Parse ``config`` into boosting-tree params and build the host model.

    A fresh ``BoostingTreeParam`` is filled in by
    ``ParamExtract.parse_param_from_config``. The result is kept on
    ``self.secureboosting_tree_param`` and then used to construct
    ``self.model`` as a ``HeteroSecureBoostingTreeHost``.

    Args:
        config: Passed straight through to ``parse_param_from_config``.
            This may be a file path or an already-loaded mapping,
            depending on that helper's contract (not visible here).
    """
    default_param = BoostingTreeParam()
    parsed_param = ParamExtract.parse_param_from_config(default_param, config)
    self.secureboosting_tree_param = parsed_param
    self.model = HeteroSecureBoostingTreeHost(parsed_param)
def test_param_embedding(self):
    """Values from the config at ``self.config_path`` reach nested and top-level params.

    This checks both a nested field (``tree_param.criterion_method``) and a
    direct field (``task_type``) on the parsed ``BoostingTreeParam``.
    """
    boosting_tree_param = BoostingTreeParam()
    extractor = ParamExtract()
    boosting_tree_param = extractor.parse_param_from_config(
        boosting_tree_param, self.config_path)
    # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
    self.assertEqual(boosting_tree_param.tree_param.criterion_method,
                     "test_decisiontree")
    self.assertEqual(boosting_tree_param.task_type, "test_boostingtree")
def setUp(self):
    """Create fresh param objects and an in-memory nested config dict.

    The config has a single top-level ``"BoostingTreeParam"`` section. That
    section holds nested ``init_param`` / ``tree_param`` sub-sections, a
    direct ``task_type`` value, and a ``test_variable`` key. The
    ``test_variable`` key presumably has no matching attribute on
    BoostingTreeParam; verify against the param class.
    """
    self.init_param = InitParam()
    self.boosting_tree_param = BoostingTreeParam()
    boosting_section = {
        "init_param": {"init_method": "test_init", "fit_intercept": False},
        "tree_param": {"criterion_method": "test_decisiontree"},
        "task_type": "test_boostingtree",
        "test_variable": "test",
    }
    self.config_dict = {"BoostingTreeParam": boosting_section}
def setUp(self):
    """Create fresh param objects and write a JSON config file for the tests.

    The config is keyed by param class name (``"InitParam"``,
    ``"DecisionTreeParam"``, ``"BoostingTreeParam"``). Its path is stored in
    ``self.config_path``. The file is removed automatically after each test.
    """
    import contextlib
    import json
    import os
    import tempfile

    self.init_param = InitParam()
    self.boosting_tree_param = BoostingTreeParam()
    config_dict = {
        "InitParam": {"init_method": "test_init", "fit_intercept": False},
        "DecisionTreeParam": {"criterion_method": "test_decisiontree"},
        "BoostingTreeParam": {"task_type": "test_boostingtree"},
    }

    # mkstemp creates a unique file atomically. The previous
    # millisecond-timestamp name in the CWD could collide between
    # concurrently running test processes and was never cleaned up here.
    fd, self.config_path = tempfile.mkstemp(prefix="param_config_test.")

    def _remove_config():
        # Tolerate a tearDown (if any) having already deleted the file.
        with contextlib.suppress(FileNotFoundError):
            os.remove(self.config_path)

    # Register the cleanup before writing so a failed write cannot leak the file.
    self.addCleanup(_remove_config)
    with os.fdopen(fd, "w") as fout:
        fout.write(json.dumps(config_dict))
def test_param_embedding(self):
    """A nested ``tree_param`` sub-section in ``self.config_dict`` fills ``tree_param``."""
    boosting_tree_param = BoostingTreeParam()
    extractor = ParamExtract()
    boosting_tree_param = extractor.parse_param_from_config(
        boosting_tree_param, self.config_dict)
    # The former debug print was removed: on failure, assertEqual already
    # reports the actual criterion_method value.
    self.assertEqual(boosting_tree_param.tree_param.criterion_method,
                     "test_decisiontree")
def test_undefine_variable_extract(self):
    """Config keys with no matching param attribute are not added to the param.

    ``self.config_dict`` carries a ``"test_variable"`` key. After parsing,
    the resulting ``BoostingTreeParam`` must not have that attribute.
    """
    boosting_tree_param = BoostingTreeParam()
    extractor = ParamExtract()
    boosting_tree_param = extractor.parse_param_from_config(
        boosting_tree_param, self.config_dict)
    self.assertFalse(hasattr(boosting_tree_param, "test_variable"))
def test_directly_extract(self):
    """A top-level value in the ``BoostingTreeParam`` section sets ``task_type`` directly."""
    boosting_tree_param = BoostingTreeParam()
    extractor = ParamExtract()
    boosting_tree_param = extractor.parse_param_from_config(
        boosting_tree_param, self.config_dict)
    # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
    self.assertEqual(boosting_tree_param.task_type, "test_boostingtree")