コード例 #1
0
ファイル: test_sklearn.py プロジェクト: krfricke/xgboost_ray
    def test_save_load_model(self):
        """Round-trip models through ``save_model`` / ``load_model``.

        First runs the ``save_load_model`` helper against a path without an
        extension and against a ``.json`` path, each in a fresh temporary
        directory. Then trains a booster with the native ``xgb.train`` API,
        saves it, and checks that both ``RayXGBClassifier`` and
        ``xgb.XGBModel`` can load that file and reproduce the booster's
        predictions.
        """
        self._init_ray()

        # Exercise the helper once per file name, each in its own temp dir.
        for file_name in ("digits.model", "digits.model.json"):
            with TemporaryDirectory() as tempdir:
                self.save_load_model(os.path.join(tempdir, file_name))

        from sklearn.datasets import load_digits

        with TemporaryDirectory() as tempdir:
            json_path = os.path.join(tempdir, "digits.model.json")
            digits = load_digits(n_class=2)
            target = digits["target"]
            features = digits["data"]
            params = {
                "tree_method": "hist",
                "objective": "binary:logistic"
            }
            native_booster = xgb.train(
                params,
                dtrain=xgb.DMatrix(features, target),
                num_boost_round=4,
            )
            expected = native_booster.predict(xgb.DMatrix(features))
            native_booster.save_model(json_path)

            # A classifier loaded from the natively trained file must return
            # one probability row per sample and two columns (binary task).
            ray_clf = RayXGBClassifier()
            ray_clf.load_model(json_path)

            probabilities = ray_clf.predict_proba(features)
            assert probabilities.shape[0] == features.shape[0]
            assert probabilities.shape[1] == 2  # binary

            # Column 1 of predict_proba should match the native booster output.
            positive_proba = ray_clf.predict_proba(features)[:, 1]
            assert np.allclose(expected, positive_proba)

            # xgb.XGBModel must load the same file and agree as well.
            base_model = xgb.XGBModel()
            base_model.load_model(json_path)
            base_predictions = base_model.predict(features)
            assert np.allclose(expected, base_predictions)
コード例 #2
0
ファイル: test_sklearn.py プロジェクト: krfricke/xgboost_ray
    def test_estimator_type(self):
        """Check ``_estimator_type`` tags and that loading respects them.

        Each Ray estimator class must carry the expected ``_estimator_type``.
        A model saved from a fitted classifier must then be rejected with
        ``TypeError`` by ``RayXGBRegressor.load_model`` and accepted by a
        fresh ``RayXGBClassifier``.
        """
        self._init_ray()

        expected_tags = (
            (RayXGBClassifier, "classifier"),
            (RayXGBRFClassifier, "classifier"),
            (RayXGBRegressor, "regressor"),
            (RayXGBRFRegressor, "regressor"),
            (RayXGBRanker, "ranker"),
        )
        for estimator_cls, tag in expected_tags:
            assert estimator_cls._estimator_type == tag

        from sklearn.datasets import load_digits

        features, labels = load_digits(n_class=2, return_X_y=True)
        fitted_clf = RayXGBClassifier(n_estimators=2).fit(features, labels)
        with tempfile.TemporaryDirectory() as tmpdir:
            saved_path = os.path.join(tmpdir, "cls.json")
            fitted_clf.save_model(saved_path)

            # A regressor must refuse a model file produced by a classifier.
            regressor = RayXGBRegressor()
            with self.assertRaises(TypeError):
                regressor.load_model(saved_path)

            # A fresh classifier loads the same file without error.
            fresh_clf = RayXGBClassifier()
            fresh_clf.load_model(saved_path)
コード例 #3
0
ファイル: test_sklearn.py プロジェクト: krfricke/xgboost_ray
    def save_load_model(self, model_path):
        """Fit, save, reload and validate a ``RayXGBClassifier`` per CV fold.

        The binary (``n_class=2``) digits data is split with a shuffled 2-fold
        ``KFold`` seeded by ``self.rng``. For each fold the method:

        * fits a ``RayXGBClassifier(use_label_encoder=False)`` on the training
          rows and saves it to ``model_path``;
        * reloads it into a fresh ``RayXGBClassifier`` and checks that
          ``use_label_encoder``, ``classes_`` and ``_Booster`` are restored;
        * requires a test-fold misclassification rate below 0.1 and no
          ``"scikit_learn"`` attribute on the loaded booster;
        * compares margin predictions with a native ``xgb.Booster`` loaded
          from the same file;
        * requires ``xgb.XGBModel.load_model`` to raise ``TypeError`` on the
          file.

        Args:
            model_path: File path the model is written to and read back from.
                It is overwritten on every fold.
        """
        from sklearn.datasets import load_digits
        from sklearn.model_selection import KFold

        digits = load_digits(n_class=2)
        y = digits["target"]
        X = digits["data"]
        kf = KFold(n_splits=2, shuffle=True, random_state=self.rng)
        for train_index, test_index in kf.split(X, y):
            xgb_model = RayXGBClassifier(use_label_encoder=False).fit(
                X[train_index], y[train_index])
            xgb_model.save_model(model_path)

            xgb_model = RayXGBClassifier()
            xgb_model.load_model(model_path)

            # Constructor arguments and fitted state must survive the reload.
            assert xgb_model.use_label_encoder is False
            assert isinstance(xgb_model.classes_, np.ndarray)
            assert isinstance(xgb_model._Booster, xgb.Booster)

            preds = xgb_model.predict(X[test_index])
            labels = y[test_index]
            # Fraction of test rows whose thresholded prediction differs from
            # the label (vectorized form of the per-row comparison).
            err = np.mean((np.asarray(preds) > 0.5).astype(int) != labels)
            assert err < 0.1
            assert xgb_model.get_booster().attr("scikit_learn") is None

            # test native booster
            preds = xgb_model.predict(X[test_index], output_margin=True)
            booster = xgb.Booster(model_file=model_path)
            predt_1 = booster.predict(
                xgb.DMatrix(X[test_index]), output_margin=True)
            assert np.allclose(preds, predt_1)

            # Construct outside assertRaises so the expected TypeError can only
            # come from load_model itself, not from the constructor.
            plain_model = xgb.XGBModel()
            with self.assertRaises(TypeError):
                plain_model.load_model(model_path)