Exemplo n.º 1
0
 def test_reg_build_fn(self):
     """Run the sample-weight check on a KerasRegressor built from ``build_fn_reg``.

     The regressor is configured with the shared HIDDEN_DIM / BATCH_SIZE /
     EPOCHS settings; ``self.check_sample_weights_work`` performs the actual
     assertions.
     """
     # Named ``reg`` (not ``clf``) because this is a regressor, matching the
     # naming used by the other KerasRegressor tests.
     reg = wrappers.KerasRegressor(
         build_fn=build_fn_reg,
         hidden_dim=HIDDEN_DIM,
         batch_size=BATCH_SIZE,
         epochs=EPOCHS,
     )
     self.check_sample_weights_work(reg)
Exemplo n.º 2
0
    def test_regression_build_fn(self):
        """Run the shared regression checks on a KerasRegressor.

        The regressor is built from ``build_fn_reg`` using the shared
        HIDDEN_DIM / BATCH_SIZE / EPOCHS settings; ``assert_regression_works``
        performs the actual checks.
        """
        model_kwargs = dict(
            hidden_dim=HIDDEN_DIM,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
        )
        reg = wrappers.KerasRegressor(build_fn=build_fn_reg, **model_kwargs)
        assert_regression_works(reg)
Exemplo n.º 3
0
    def test_regression_build_fn(self):
        """Calling ``predict`` on an unfitted KerasRegressor raises NotFittedError."""
        # NOTE(review): this listing contains another test with this exact
        # name; if both end up in the same class, the later definition
        # silently replaces the earlier one — consider renaming one of them.
        model_kwargs = dict(
            hidden_dim=HIDDEN_DIM,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
        )
        reg = wrappers.KerasRegressor(build_fn=build_fn_reg, **model_kwargs)

        # Any array will do: the test only cares that predict refuses to run
        # on a model that has never been fitted.
        features = np.random.rand(10, 20)

        pytest.raises(NotFittedError, reg.predict, features)
Exemplo n.º 4
0
    def test_regression_class_build_fn(self):
        """Run the shared regression checks with a callable-object ``build_fn``.

        Here ``build_fn`` is an instance of a local class implementing
        ``__call__`` (rather than a plain function); the call is forwarded to
        ``build_fn_reg``.
        """

        class ClassBuildFnReg:
            def __call__(self, hidden_dim):
                # Parameter is named ``hidden_dim`` to line up with the
                # hidden_dim keyword given to KerasRegressor below
                # (presumably the wrapper matches build_fn params by name).
                return build_fn_reg(hidden_dim)

        builder = ClassBuildFnReg()
        model_kwargs = dict(
            hidden_dim=HIDDEN_DIM,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
        )
        reg = wrappers.KerasRegressor(build_fn=builder, **model_kwargs)

        assert_regression_works(reg)