예제 #1
0
파일: test_model.py 프로젝트: ityutin/trava
def test_get_model(mocker, raw_model, model_id, X, y, needs_proba):
    """Check ``get_model`` before and after ``unload_model``.

    Before unloading, both the train-side and test-side lookups must return
    the raw model itself. After fit/predict followed by ``unload_model``, both
    lookups must return objects other than the raw model, and the train-side
    object's ``predict`` (and ``predict_proba`` when probabilities are needed)
    must return the values the raw model produced.
    """
    # ``model_id`` is requested as a fixture like in the sibling tests; the
    # original body referenced it without requesting it.
    model = TravaModel(raw_model=raw_model, model_id=model_id)

    assert model.get_model(for_train=True) == raw_model
    assert model.get_model(for_train=False) == raw_model

    y_predict_proba = mocker.Mock()
    if needs_proba:
        raw_model.predict_proba.return_value = y_predict_proba

    y_pred = mocker.Mock()
    raw_model.predict.return_value = y_pred

    model.fit(X=X, y=y)
    model.predict(X=X, y=y)

    model.unload_model()

    train_cached_model = model.get_model(for_train=True)
    test_cached_model = model.get_model(for_train=False)

    # After unloading, the raw model must no longer be handed out.
    assert train_cached_model != raw_model
    assert test_cached_model != raw_model

    assert train_cached_model.predict(X) == y_pred
    if needs_proba:
        assert train_cached_model.predict_proba(X) == y_predict_proba
예제 #2
0
파일: test_model.py 프로젝트: ityutin/trava
def test_all_y(mocker, raw_model, model_id, X, y, fit_params, predict_params, needs_proba):
    """Check that targets and predictions are tracked separately per split.

    The raw model's mocked outputs are swapped between ``fit`` (train split)
    and ``predict`` (test split). ``y``, ``y_pred`` and, when probabilities
    are needed, ``y_pred_proba`` must then report each split's own values.
    """
    model = TravaModel(raw_model=raw_model, model_id=model_id)

    # Train split: outputs the raw model yields while ``fit`` runs.
    train_proba = mocker.Mock()
    if needs_proba:
        raw_model.predict_proba.return_value = train_proba
    train_pred = mocker.Mock()
    raw_model.predict.return_value = train_pred
    model.fit(X=X, y=y, fit_params=fit_params, predict_params=predict_params)

    # Test split: fresh outputs the raw model yields while ``predict`` runs.
    test_proba = mocker.Mock()
    if needs_proba:
        raw_model.predict_proba.return_value = test_proba
    test_pred = mocker.Mock()
    raw_model.predict.return_value = test_pred
    features_test = mocker.Mock()
    target_test = mocker.Mock()
    model.predict(X=features_test, y=target_test)

    for is_train, wanted_pred in ((True, train_pred), (False, test_pred)):
        assert model.y_pred(for_train=is_train) == wanted_pred
    for is_train, wanted_target in ((True, y), (False, target_test)):
        assert model.y(for_train=is_train) == wanted_target
    if needs_proba:
        for is_train, wanted_proba in ((True, train_proba), (False, test_proba)):
            assert model.y_pred_proba(for_train=is_train) == wanted_proba
예제 #3
0
 def _fit(self, trava_model: TravaModel, X, y, fit_params: dict,
          predict_params: dict):
     """Fit ``trava_model`` on the given data.

     Override this hook if you want to control the fit process. The default
     implementation simply forwards every argument to ``trava_model.fit``.

     :param trava_model: the model wrapper to fit.
     :param X: features forwarded as ``X``.
     :param y: targets forwarded as ``y``.
     :param fit_params: forwarded unchanged as ``fit_params``.
     :param predict_params: forwarded unchanged as ``predict_params``.
     """
     trava_model.fit(
         X=X, y=y, fit_params=fit_params, predict_params=predict_params,
     )
예제 #4
0
    def _fit(self, trava_model: TravaModel, X, y, fit_params: dict,
             predict_params: dict):
        """Fit ``trava_model`` once, then call ``copy`` for each other group model.

        Does nothing if ``self._is_raw_model_ready`` is already set. Otherwise
        it fits ``trava_model`` and calls
        ``trava_model.copy(existing_model=..., only_fit=True)`` for every entry
        of ``self._group_models`` that compares unequal to ``trava_model``.
        It then sets ``self._is_raw_model_ready`` to ``True``.

        NOTE(review): ``copy`` presumably transfers the fitted state into the
        other group models; confirm against ``TravaModel.copy``.
        """
        if self._is_raw_model_ready:
            return

        trava_model.fit(
            X=X, y=y, fit_params=fit_params, predict_params=predict_params,
        )

        # Lazily skip the model that was just fitted; ``!=`` is kept so any
        # custom equality on the models is honoured exactly as before.
        others = (member for member in self._group_models
                  if member != trava_model)
        for other in others:
            trava_model.copy(existing_model=other, only_fit=True)

        # Set only after fitting and copying succeed, so a failure lets a
        # later call retry.
        self._is_raw_model_ready = True
예제 #5
0
파일: test_model.py 프로젝트: ityutin/trava
def test_fit(mocker, raw_model, model_id, X, y, fit_params, predict_params, needs_proba):
    """Check that ``fit`` delegates to the raw model and records a fit time.

    The raw model's ``fit`` must be called once with the fit params, and
    ``predict`` must be called once with the predict params. When
    probabilities are needed, ``predict_proba`` must be called with the
    predict params.
    """
    raw_model.predict.return_value = mocker.Mock()
    if needs_proba:
        raw_model.predict_proba.return_value = mocker.Mock()

    trava_model = TravaModel(raw_model=raw_model, model_id=model_id)
    trava_model.fit(X=X, y=y, fit_params=fit_params, predict_params=predict_params)

    raw_model.fit.assert_called_once_with(X, y, **fit_params)
    raw_model.predict.assert_called_once_with(X, **predict_params)
    if needs_proba:
        raw_model.predict_proba.assert_called_with(X, **predict_params)

    assert trava_model.fit_time
예제 #6
0
파일: test_model.py 프로젝트: ityutin/trava
def test_fit_time(mocker, raw_model, model_id, X, y):
    """Check that ``fit_time`` is falsy before ``fit`` and truthy afterwards."""
    trava_model = TravaModel(raw_model=raw_model, model_id=model_id)

    elapsed_before = trava_model.fit_time
    assert not elapsed_before

    trava_model.fit(X=X, y=y)

    elapsed_after = trava_model.fit_time
    assert elapsed_after
예제 #7
0
파일: test_model.py 프로젝트: ityutin/trava
def test_predict_params(mocker, raw_model, model_id, X, y, fit_params, predict_params):
    """Check that the ``predict_params`` passed to ``fit`` are exposed on the model."""
    trava_model = TravaModel(raw_model=raw_model, model_id=model_id)
    trava_model.fit(X=X, y=y, fit_params=fit_params, predict_params=predict_params)

    stored_params = trava_model.predict_params
    assert stored_params == predict_params