示例#1
0
    def test_roll_size(self):
        """Check that ``roll_size`` controls how far the test window advances.

        With ``gap_size=2``, ``max_test_size=2`` and ``roll_size=1`` on seven
        samples, the test window slides forward one index per split, so
        consecutive test windows overlap. Two indices are always skipped
        between the end of the training set and the start of the test set.
        The final test window is truncated to ``[6]`` because the data runs out.
        """
        X = range(7)
        foo = GapRollForward(gap_size=2, max_test_size=2, roll_size=1)
        splits = foo.split(X)
        assert_equal(foo.get_n_splits(X), 5)

        train, test = next(splits)
        assert_array_equal(train, [])
        assert_array_equal(test, [2, 3])

        train, test = next(splits)
        assert_array_equal(train, [0])
        assert_array_equal(test, [3, 4])

        train, test = next(splits)
        assert_array_equal(train, [0, 1])
        assert_array_equal(test, [4, 5])

        train, test = next(splits)
        assert_array_equal(train, [0, 1, 2])
        assert_array_equal(test, [5, 6])

        train, test = next(splits)
        assert_array_equal(train, [0, 1, 2, 3])
        assert_array_equal(test, [6])

        # The generator must stop after exactly get_n_splits() splits;
        # without this an implementation yielding extra splits would pass.
        assert next(splits, None) is None
示例#2
0
    def test_min_train_size(self):
        """Check that no split is produced until ``min_train_size`` is reached.

        With ``min_train_size=2``, ``gap_size=2`` and ``max_test_size=2`` on
        eight samples, the first split already trains on two samples. In this
        configuration the test window then advances by two samples, giving
        exactly two splits.
        """
        X = range(8)
        foo = GapRollForward(gap_size=2, max_test_size=2, min_train_size=2)
        splits = foo.split(X)
        assert_equal(foo.get_n_splits(X), 2)

        train, test = next(splits)
        assert_array_equal(train, [0, 1])
        assert_array_equal(test, [4, 5])

        train, test = next(splits)
        assert_array_equal(train, [0, 1, 2, 3])
        assert_array_equal(test, [6, 7])

        # The generator must stop after exactly get_n_splits() splits.
        assert next(splits, None) is None
示例#3
0
    def test_max_train_size(self):
        """Check that the training window is capped at ``max_train_size``.

        With ``max_train_size=1`` on three samples, the training set never
        exceeds one sample. On the last split it holds only the most recent
        preceding index (``[1]``), not ``[0, 1]``.
        """
        X = range(3)
        foo = GapRollForward(max_train_size=1)
        splits = foo.split(X)
        assert_equal(foo.get_n_splits(X), 3)

        train, test = next(splits)
        assert_array_equal(train, [])
        assert_array_equal(test, [0])

        train, test = next(splits)
        assert_array_equal(train, [0])
        assert_array_equal(test, [1])

        train, test = next(splits)
        assert_array_equal(train, [1])
        assert_array_equal(test, [2])

        # The generator must stop after exactly get_n_splits() splits.
        assert next(splits, None) is None
示例#4
0
    def test_max_test_size(self):
        """Check that test windows hold at most ``max_test_size`` samples.

        With ``max_test_size=2`` on five samples, the data is cut into test
        windows ``[0, 1]``, ``[2, 3]`` and a truncated ``[4]``. Each training
        set contains every sample before its test window.
        """
        X = range(5)
        foo = GapRollForward(max_test_size=2)
        splits = foo.split(X)
        assert_equal(foo.get_n_splits(X), 3)

        train, test = next(splits)
        assert_array_equal(train, [])
        assert_array_equal(test, [0, 1])

        train, test = next(splits)
        assert_array_equal(train, [0, 1])
        assert_array_equal(test, [2, 3])

        train, test = next(splits)
        assert_array_equal(train, [0, 1, 2, 3])
        assert_array_equal(test, [4])

        # The generator must stop after exactly get_n_splits() splits.
        assert next(splits, None) is None
示例#5
0
    def test_default_input(self):
        """Check the splits produced with all-default constructor arguments.

        On three samples, every split tests a single sample and trains on
        all samples before it, so the training set expands by one each time.
        """
        X = range(3)
        foo = GapRollForward()
        assert_equal(foo.get_n_splits(X), 3)
        splits = foo.split(X)

        train, test = next(splits)
        assert_array_equal(train, [])
        assert_array_equal(test, [0])

        train, test = next(splits)
        assert_array_equal(train, [0])
        assert_array_equal(test, [1])

        train, test = next(splits)
        assert_array_equal(train, [0, 1])
        assert_array_equal(test, [2])

        # The generator must stop after exactly get_n_splits() splits.
        assert next(splits, None) is None
示例#6
0
    def test_invalid_input(self):
        """Splitting must fail when the requested sizes leave no valid split.

        Ten samples with ``min_train_size=3`` and ``gap_size=7`` use up all
        3 + 7 = 10 indices, leaving no room for any test sample. The split
        iterator is advanced with ``next()`` inside the ``raises`` block, so
        the error is expected on the first iteration.
        """
        samples = range(10)
        splitter = GapRollForward(min_train_size=3, gap_size=7)
        expected_message = "No valid splits for the input arguments."

        with pytest.raises(ValueError, match=expected_message):
            next(splitter.split(samples))