Esempio n. 1
0
 def test_fit(self):
     """Fit an AutoXGBRegressor and check the best model's hyperparameters.

     Runs ``fit`` with ``n_sampling=4`` and ``epochs=1`` over the search
     space returned by ``create_XGB_recipe()`` using the "mse" metric, then
     asserts that the best model's underlying ``model`` has
     ``n_estimators`` within [5, 10] and ``max_depth`` within [2, 5].
     """
     regressor = AutoXGBRegressor(
         cpus_per_trial=2,
         name="auto_xgb_regressor",
         tree_method='hist',
     )
     train_set, val_set = get_data()

     # Collect the fit arguments in one place before launching the search.
     fit_kwargs = {
         "data": train_set,
         "validation_data": val_set,
         "search_space": create_XGB_recipe(),
         "n_sampling": 4,
         "epochs": 1,
         "metric": "mse",
     }
     regressor.fit(**fit_kwargs)

     # The hyperparameters are read from the ``model`` attribute of the
     # object returned by get_best_model().
     fitted = regressor.get_best_model().model
     assert 5 <= fitted.n_estimators <= 10
     assert 2 <= fitted.max_depth <= 5
 def test_fit(self):
     """Fit an AutoXGBRegressor and check the best model and best config.

     Runs ``fit`` with ``n_sampling=4`` and ``epochs=1`` over the search
     space returned by ``get_xgb_search_space()`` using the "mae" metric.
     It then asserts that the best model has ``n_estimators`` within
     [5, 10] and ``max_depth`` within [2, 5], and that every key of the
     search space is present in the best config.
     """
     regressor = AutoXGBRegressor(
         cpus_per_trial=2,
         name="auto_xgb_regressor",
         tree_method='hist',
     )
     train_set, val_set = get_data()

     # Collect the fit arguments in one place before launching the search.
     fit_kwargs = {
         "data": train_set,
         "validation_data": val_set,
         "search_space": get_xgb_search_space(),
         "n_sampling": 4,
         "epochs": 1,
         "metric": "mae",
     }
     regressor.fit(**fit_kwargs)

     # Here the hyperparameters are read directly from the returned model.
     winner = regressor.get_best_model()
     assert 5 <= winner.n_estimators <= 10
     assert 2 <= winner.max_depth <= 5

     # Every key of the search space must be present in the best config.
     best_config = regressor.get_best_config()
     absent = [key for key in get_xgb_search_space().keys()
               if key not in best_config.keys()]
     assert not absent
Esempio n. 3
0
            max_depth=list(max_depth_range),
            lr=lr,
            min_child_weight=min_child_weight)
        search_alg = None
        search_alg_params = None
        scheduler = None
        scheduler_params = None

    # Build the regressor; any further constructor keyword arguments come
    # from `config`, which is defined outside this excerpt.
    auto_xgb_reg = AutoXGBRegressor(cpus_per_trial=2,
                                    name="auto_xgb_regressor",
                                    **config)
    # Run the search on the training split, validating on (X_val, y_val)
    # with "rmse" as the metric. The sample count and search space both come
    # from `recipe`.
    # NOTE(review): search_alg_params is passed as a literal None here rather
    # than the `search_alg_params` variable assigned earlier, so any value
    # stored in that variable is ignored — confirm this is intentional.
    auto_xgb_reg.fit(data=(X_train, y_train),
                     validation_data=(X_val, y_val),
                     metric="rmse",
                     n_sampling=recipe.num_samples,
                     search_space=recipe.search_space(),
                     search_alg=search_alg,
                     search_alg_params=None,
                     scheduler=scheduler,
                     scheduler_params=scheduler_params)

    print("Training completed.")
    # Retrieve the best model found by the search and predict on the
    # validation features. y_hat is not used again in the visible part of
    # this script.
    best_model = auto_xgb_reg.get_best_model()
    y_hat = best_model.predict(X_val)

    # Score the best model on the validation split with the "rmse" metric.
    rmse = best_model.evaluate(X_val, y_val, metrics=["rmse"])
    print("Evaluate: the square root of mean square error is", rmse)

    # Stop the runtime handles created outside this excerpt (presumably the
    # Ray context and the SparkContext — verify against their creation site).
    ray_ctx.stop()
    sc.stop()