예제 #1
0
 def test_lasso_admm_zero(self):
     """LassoADMM must fit all-zero data without crashing.

     A single all-zero feature with an all-zero target is expected to give
     a zero coefficient and zero predictions for arbitrary test inputs.
     """
     zero_X = [[0], [0], [0]]
     zero_y = [0] * 3
     fitted = LassoADMM(alpha=0.1).fit(zero_X, zero_y)
     predictions = fitted.predict([[1], [2], [3]])
     assert_array_almost_equal(fitted.coef_, [0])
     assert_array_almost_equal(predictions, [0, 0, 0])
예제 #2
0
    def test_lasso_admm(self):
        """Score several LassoADMM configurations on the build_dataset() split.

        Every configuration is fitted on the same training data and checked
        via its ``score`` on the held-out split:

        * alpha=0.05 with tol=1e-8: score > 0.99 in fewer than 150 iterations;
        * alpha=0.05 without an intercept: score > 0.99;
        * normalize=True (alpha=0.144, rho=0.1): only score > 0.60 is
          required, because this configuration fits noticeably worse.
        """
        X_train, y_train, X_held, y_held = build_dataset()

        # Tight tolerance: near-perfect fit, reached quickly.
        tight = LassoADMM(alpha=0.05, tol=1e-8).fit(X_train, y_train)
        self.assertGreater(tight.score(X_held, y_held), 0.99)
        self.assertLess(tight.n_iter_, 150)

        no_intercept = LassoADMM(alpha=0.05, fit_intercept=False).fit(
            X_train, y_train)
        self.assertGreater(no_intercept.score(X_held, y_held), 0.99)

        # With normalize=True the fit is markedly worse, hence the looser
        # 0.60 threshold and the hand-picked alpha/rho.
        normalized = LassoADMM(alpha=0.144, rho=0.1, normalize=True).fit(
            X_train, y_train)
        self.assertGreater(normalized.score(X_held, y_held), 0.60)
예제 #3
0
    def test_lasso_admm_toy_multi(self):
        """Regression test for issue #39: multi-target fit on an identity design.

        With X = I(4) and a 4x3 target matrix, each row of ``coef_`` is
        expected to be +/-0.3 (to 3 decimals) in the sign pattern below.
        """
        X = np.eye(4)
        y = np.array([[1, 1, 0],
                      [1, 0, 1],
                      [0, 1, 0],
                      [0, 0, 1]])

        clf = LassoADMM(alpha=0.05, tol=1e-8).fit(X, y)

        c = 0.29999988
        expected_rows = [
            [c, c, -c, -c],
            [c, -c, c, -c],
            [-c, c, -c, c],
        ]
        for row_index, expected in enumerate(expected_rows):
            assert_array_almost_equal(clf.coef_[row_index], expected, decimal=3)
예제 #4
0
파일: hakarus.py 프로젝트: wasa1999/spm
from spmimage.linear_model import LassoADMM

if __name__ == '__main__':
    data = pd.DataFrame(pd.read_csv('~/Downloads/data.csv'), dtype='float')
    # Candidate regularisation strengths, passed positionally to LassoADMM
    # (presumably `alpha`), each evaluated with 2-fold cross-validation.
    param_list = [0.001, 0.01, 0.1, 1, 10, 100]
    scoresbyparam = np.zeros(0)  # mean fold score, one entry per candidate
    kf = KFold(n_splits=2)
    for l in param_list:
        print("current lambd is " + str(l))
        # Fold scores for the current candidate only. This was previously
        # initialised once before the outer loop, so each entry of
        # `scoresbyparam` averaged the folds of *all* candidates seen so far
        # and the selection below was biased towards earlier candidates.
        scores = np.zeros(0)
        for train, test in kf.split(data):
            # Column 1 is the target; columns 2..6 are the features.
            train_x = data.iloc[train, 2:7]
            train_y = data.iloc[train, 1]
            test_x = data.iloc[test, 2:7]
            test_y = data.iloc[test, 1]
            model = LassoADMM(l)
            model.fit(train_x, train_y)
            pre = model.predict(test_x)
            rmse = np.sqrt(mean_squared_error(test_y, pre))
            mae = mean_absolute_error(test_y, pre)
            # 1.253 is presumably sqrt(pi / 2), the RMSE/MAE ratio of
            # normally distributed errors -- TODO confirm the intended metric.
            score = 1.253 - (rmse / mae)
            print("score : " + str(score))
            print("coef_ : " + str(model.coef_))
            scores = np.append(scores, score)
        scoresbyparam = np.append(scoresbyparam, np.mean(scores))
    # The candidate with the smallest mean score is selected. `min` shadows
    # the builtin, but the name is kept because code later in this script
    # may refer to it.
    min = np.min(scoresbyparam)
    min_index = np.argmin(scoresbyparam)
    print("best lambd is " + str(param_list[min_index]))
    model = LassoADMM(param_list[min_index])
    X_train, X_test, y_train, y_test = train_test_split(data.iloc[:, 2:7],
                                                        data.iloc[:, 1],
예제 #5
0
    def test_lasso_admm_toy(self):
        """Check LassoADMM on sklearn's one-feature toy problem.

        The data are three points on the line y = x, the same case used by
        scikit-learn's Lasso tests
        (https://github.com/scikit-learn/scikit-learn/blob/master
        /sklearn/linear_model/tests/test_coordinate_descent.py).
        Each case builds a model with the listed constructor arguments, fits
        it, and compares ``coef_`` and the predictions on ``T`` to three
        decimal places.

        alpha = 0 should not be allowed, so the near-zero alpha = 1e-8 is
        tested as a border case instead.

        WARNING: LassoADMM does not check the dual gap yet, so this test
        cannot detect a run that stopped before converging. This is meant
        to be fixed in the future.
        """
        X = np.array([[-1.], [0.], [1.]])
        Y = [-1, 0, 1]  # just a straight line
        T = [[2.], [3.], [4.]]  # test sample

        line_fit = ([1], [2, 3, 4])
        shrunk_fit = ([0.249], [0.498, 0.746, 0.995])
        cases = [
            (dict(alpha=1e-8), line_fit),
            (dict(alpha=0.1), ([.85], [1.7, 2.55, 3.4])),
            (dict(alpha=0.5), ([.254], [0.508, 0.762, 1.016])),
            (dict(alpha=1), ([.0], [0, 0, 0])),
            # Same expectation as alpha=1e-8: the default rho is 1.0.
            (dict(alpha=1e-8, rho=1.0), line_fit),
            # max_iter sweep (default max_iter = 1000).
            (dict(alpha=0.5, rho=0.3, max_iter=50), shrunk_fit),
            (dict(alpha=0.5, rho=0.3, max_iter=100), shrunk_fit),
            (dict(alpha=0.5, rho=0.3, max_iter=500), shrunk_fit),
            (dict(alpha=0.5, rho=0.3, max_iter=1000), shrunk_fit),
            (dict(alpha=0.5, rho=0.5), shrunk_fit),
        ]

        for params, (want_coef, want_pred) in cases:
            clf = LassoADMM(**params)
            clf.fit(X, Y)
            pred = clf.predict(T)
            assert_array_almost_equal(clf.coef_, want_coef, decimal=3)
            assert_array_almost_equal(pred, want_pred, decimal=3)