Пример #1
0
def test_warm_start():
    """Verify that ``warm_start`` decides whether ``fit`` reuses ``model_``.

    After switching ``warm_start`` on, a repeated ``fit`` must keep the very
    same ``model_`` object. After switching it off, every ``fit`` must
    replace ``model_`` with a different object.
    """
    # Small slice of the California housing data keeps the test quick.
    housing = fetch_california_housing()
    features = housing.data[:100]
    targets = housing.target[:100]

    # Baseline fit with the default constructor settings.
    reg = KerasRegressor(
        model=dynamic_regressor,
        model__hidden_layer_sizes=(100,),
    )
    reg.fit(features, targets)
    previous = reg.model_

    # warm_start=True: refitting must reuse the existing model object.
    reg.set_params(warm_start=True)
    reg.fit(features, targets)
    assert previous is reg.model_

    # warm_start=False: each refit must yield a model object distinct from
    # the one produced by the preceding fit.
    reg.set_params(warm_start=False)
    for _attempt in range(3):
        reg.fit(features, targets)
        assert previous is not reg.model_
        previous = reg.model_
Пример #2
0
    def test_current_epoch_property(self, warm_start, epochs_prefix):
        """Test the public ``current_epoch`` property, which tracks the
        overall number of training epochs.

        The ``warm_start`` parameter should have NO impact on this behavior.

        The prefix should NOT have any impact on behavior either. It is
        tested because the ``epochs`` param has special handling within
        param routing.

        Parameters
        ----------
        warm_start : bool
            Forwarded to the ``KerasRegressor`` constructor. Presumably
            supplied by a pytest parametrize decorator not visible here.
        epochs_prefix : str
            Prepended to ``"epochs"`` to form the key passed to
            ``set_params`` (TODO confirm the parametrized values, e.g. an
            empty string or a routing prefix).
        """
        # NOTE(review): ``load_boston`` was removed in scikit-learn 1.2, so
        # this test can only run against older scikit-learn releases.
        data = load_boston()
        X, y = data.data[:10], data.target[:10]
        epochs = 2
        # Deliberately different from ``epochs``. After the partial_fit loop
        # the counter sits at ``partial_fit_iter``, so the ``fit`` check
        # below can only pass if ``fit`` actually resets and retrains.
        partial_fit_iter = 3

        estimator = KerasRegressor(
            model=dynamic_regressor,
            loss=KerasRegressor.r_squared,
            model__hidden_layer_sizes=[],
            warm_start=warm_start,
        )
        estimator.set_params(**{epochs_prefix + "epochs": epochs})

        # Check that each partial_fit call trains for exactly 1 epoch,
        # calling partial_fit ``partial_fit_iter`` times in total.
        for k in range(1, partial_fit_iter + 1):
            estimator.partial_fit(X, y)
            assert estimator.current_epoch == k

        # Check that fit restarts the count and trains for the number of
        # epochs configured via ``set_params`` above (not an accumulation
        # on top of the previous partial_fit epochs).
        estimator.fit(X, y)
        assert estimator.current_epoch == epochs

        # partial_fit is able to resume from a non-zero epoch
        estimator.partial_fit(X, y)
        assert estimator.current_epoch == epochs + 1