示例#1
0
 def test_predict_single_model_with_preprocess(self):
     """A single LinearRegression behind a MinMaxScaler fits an exact linear target.

     Training targets follow ``y = x0 + 2*x1 + 3``, so the query point
     ``[3, 5]`` must predict ``3 + 10 + 3 = 16``, returned as one row with
     one column (a single model in the Layer).
     """
     coefs = np.array([1, 2])
     train_x = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
     train_y = train_x @ coefs + 3

     layer = Layer([LinearRegression()], [MinMaxScaler()])
     layer.fit(train_x, train_y)

     prediction = layer.predict(np.array([[3, 5]]))

     assert prediction.shape == (1, 1)
     assert np.allclose(prediction, np.array([[16]]))
示例#2
0
    def test_predict_single_model_with_2_class_proba(self):
        """With ``proba=True`` a binary classifier yields two output columns.

        The labels contain exactly two classes (0 and 1), so predicting a
        single query row is expected to give a ``(1, 2)`` array.
        """
        features = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
        labels = np.array([1, 1, 0, 0])

        classifier = LogisticRegression(solver='liblinear')
        layer = Layer([classifier], proba=True)
        layer.fit(features, labels)

        probabilities = layer.predict(np.array([[3, 5]]))

        assert probabilities.shape == (1, 2)
示例#3
0
 def test_predict_single_model_with_multi_class_proba(self):
     """With ``proba=True`` a 3-class classifier yields one column per class.

     The labels contain three distinct classes (0, 1, 2), so predicting a
     single query row is expected to give a ``(1, 3)`` array.
     """
     # `multi_class` is deprecated in scikit-learn (since 1.5) and scheduled
     # for removal. With the lbfgs solver and more than two classes the
     # default already uses the multinomial loss, so passing it explicitly
     # only triggers a deprecation warning now and an error once removed.
     layer_model = Layer([LogisticRegression(solver='lbfgs')], proba=True)
     X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
     y = np.array([1, 1, 0, 2])
     layer_model.fit(X, y)
     result = layer_model.predict(np.array([[3, 5]]))
     assert result.shape == (1, 3)
示例#4
0
 def test_predict_multiple_model(self):
     """Two regressors in one Layer each contribute a prediction column.

     The first LinearRegression is paired with ``None`` and the second with
     a MinMaxScaler. The target is exactly ``y = x0 + 2*x1 + 3``, so both
     models should predict 16 for ``[3, 5]``, giving a ``(1, 2)`` result.
     """
     weights = np.array([1, 2])
     samples = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
     targets = samples @ weights + 3

     regressors = [LinearRegression(), LinearRegression()]
     transforms = [None, MinMaxScaler()]
     layer = Layer(regressors, transforms)
     layer.fit(samples, targets)

     output = layer.predict(np.array([[3, 5]]))

     assert output.shape == (1, 2)
     assert np.allclose(output, np.array([[16, 16]]))
示例#5
0
 def test_using_proba_without_predict_proba_method(self):
     """``proba=True`` on a model lacking ``predict_proba`` warns but still predicts.

     LinearRegression has no ``predict_proba``. The test expects at least one
     warning from the Layer construct/fit/predict sequence (which of the
     three emits it is not pinned here). It also expects ordinary regression
     output: ``y = x0 + 2*x1 + 3`` gives 16 for each ``[3, 5]`` query row.
     """
     X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
     y = np.dot(X, np.array([1, 2])) + 3

     # Keep only the Layer calls inside the capture block, so the expected
     # warning must come from them and not from data setup or assertions.
     # pytest.warns fails the test on exit if nothing was emitted.
     with pytest.warns(Warning) as record:
         layer_model = Layer([LinearRegression()], proba=True)
         layer_model.fit(X, y)
         result = layer_model.predict(np.array([[3, 5], [3, 5]]))

     assert record  # explicit: at least one warning was captured
     assert result.shape == (2, 1)
     assert np.allclose(result, np.array([[16], [16]]))