Exemplo n.º 1
0
    def test_train_split(self):
        """Check train_test_split driven by a LeaveOneOut splitter.

        A seeded toy dataset of 10 samples (labels drawn from {1, 2}, three
        uint8 features each) is split once.  Train and test parts together
        must account for every sample, and the held-out labels must compare
        equal to ``[1, 2]``.
        """
        np.random.seed(42)
        # The order of the draws matters: it fixes which values the seed yields.
        labels = np.random.randint(1, 3, 10)
        features = np.random.randint(1, 255, [10, 3], dtype=np.uint8)

        splitter = cross_validation.LeaveOneOut(random_state=42)
        feat_tr, feat_te, lab_tr, lab_te = cross_validation.train_test_split(
            splitter, features, labels)

        n_rows = features.shape[0]
        n_labels = labels.shape[0]
        assert feat_tr.shape[0] + feat_te.shape[0] == n_rows
        assert lab_tr.shape[0] + lab_te.shape[0] == n_labels

        expected_held_out = np.array([1, 2])
        assert np.equal(lab_te, expected_held_out).all()
    def test_compare_loo_kf(self):
        """Check that RandomStratifiedKFold set up as leave-one-out matches LeaveOneOut.

        Both splitters share ``random_state=12`` and are expected to produce
        the same sequence of (train, validation) index pairs.  Each
        validation set must contain ``n_class`` samples, one per class.

        Relies on the module-level ``X``, ``y`` and ``n_class`` fixtures.
        """
        cv_loo = cross_validation.LeaveOneOut(random_state=12, verbose=2)
        cv_kf_as_loo = cross_validation.RandomStratifiedKFold(n_splits=False,
                                                              valid_size=1,
                                                              random_state=12,
                                                              verbose=2)

        # Materialise both split sequences: a plain zip() stops at the shorter
        # one, so a difference in the number of splits (or an empty generator)
        # would otherwise pass unnoticed.
        splits_loo = list(cv_loo.split(X, y))
        splits_kf = list(cv_kf_as_loo.split(X, y))
        assert len(splits_loo) == len(splits_kf)

        for (tr_loo, vl_loo), (tr_kf, vl_kf) in zip(splits_loo, splits_kf):
            # array_equal also compares shapes, so broadcasting cannot mask a
            # size mismatch between the two index arrays.
            assert np.array_equal(tr_loo, tr_kf)
            assert np.array_equal(vl_loo, vl_kf)
            assert len(vl_kf) == n_class
            assert np.unique(y[vl_kf]).size == n_class

        # Smoke call intended to print the supported extensions; the return
        # value is not checked.
        cv_loo.get_supported_extensions()
    def test_loo(self):
        """Check LeaveOneOut split counts, verbosity and split sizes.

        For ``n_repeats=False`` the number of splits must equal the sample
        count of the smallest class in ``y``.  For an integer ``n_repeats``
        it must equal that integer.  ``verbose`` must be stored unchanged.
        Every validation set must hold 5 samples and the training set the
        remaining ``y.size - 5``.

        Relies on the module-level ``X`` and ``y`` fixtures.
        """
        for split in [False, 1, 2, 5]:

            cv = cross_validation.LeaveOneOut(n_repeats=split,
                                              random_state=split,
                                              verbose=split)
            # ``False`` acts as a sentinel ("no fixed repeat count"), so test it
            # by identity rather than with ``== False`` (PEP 8, flake8 E712).
            if split is False:
                smallest_class = np.min(np.unique(y, return_counts=True)[-1])
                assert cv.get_n_splits(X, y) == smallest_class
            else:
                assert cv.get_n_splits(X, y) == split
            assert cv.verbose == split

            # NOTE(review): 5 presumably equals the number of classes in the
            # module-level ``y`` (one validation sample per class) — confirm.
            for tr, vl in cv.split(X, y):
                assert tr.size == y.size - 5
                assert vl.size == 5
                assert len(vl) == 5