def test_roll_size(self):
    """With ``roll_size=1`` the two-sample test window advances one step per split.

    A gap of two samples separates the end of each training set from the
    start of its test set. The test also verifies that ``split`` yields
    exactly as many folds as ``get_n_splits`` reports.
    """
    X = range(7)
    foo = GapRollForward(gap_size=2, max_test_size=2, roll_size=1)
    assert_equal(foo.get_n_splits(X), 5)
    expected = [
        ([], [2, 3]),
        ([0], [3, 4]),
        ([0, 1], [4, 5]),
        ([0, 1, 2], [5, 6]),
        ([0, 1, 2, 3], [6]),  # last test fold is truncated by the data end
    ]
    # Materialize every split so that a surplus trailing split is caught,
    # not just a mismatch among the first len(expected) folds.
    splits = list(foo.split(X))
    assert_equal(len(splits), len(expected))
    for (train, test), (exp_train, exp_test) in zip(splits, expected):
        assert_array_equal(train, exp_train)
        assert_array_equal(test, exp_test)
def test_min_train_size(self):
    """The first split's training set already holds ``min_train_size`` samples.

    Also verifies that ``split`` yields exactly ``get_n_splits`` folds.
    """
    X = range(8)
    foo = GapRollForward(gap_size=2, max_test_size=2, min_train_size=2)
    assert_equal(foo.get_n_splits(X), 2)
    expected = [
        ([0, 1], [4, 5]),
        ([0, 1, 2, 3], [6, 7]),
    ]
    # Collect all splits so an unexpected extra fold fails the length check.
    splits = list(foo.split(X))
    assert_equal(len(splits), len(expected))
    for (train, test), (exp_train, exp_test) in zip(splits, expected):
        assert_array_equal(train, exp_train)
        assert_array_equal(test, exp_test)
def test_max_train_size(self):
    """With ``max_train_size=1`` the training window slides instead of growing.

    Also verifies that ``split`` yields exactly ``get_n_splits`` folds.
    """
    X = range(3)
    foo = GapRollForward(max_train_size=1)
    assert_equal(foo.get_n_splits(X), 3)
    expected = [
        ([], [0]),
        ([0], [1]),
        ([1], [2]),  # only the most recent sample is kept for training
    ]
    # Collect all splits so an unexpected extra fold fails the length check.
    splits = list(foo.split(X))
    assert_equal(len(splits), len(expected))
    for (train, test), (exp_train, exp_test) in zip(splits, expected):
        assert_array_equal(train, exp_train)
        assert_array_equal(test, exp_test)
def test_max_test_size(self):
    """Test folds hold up to ``max_test_size`` samples; the last may be shorter.

    Also verifies that ``split`` yields exactly ``get_n_splits`` folds.
    """
    X = range(5)
    foo = GapRollForward(max_test_size=2)
    assert_equal(foo.get_n_splits(X), 3)
    expected = [
        ([], [0, 1]),
        ([0, 1], [2, 3]),
        ([0, 1, 2, 3], [4]),  # truncated by the end of the data
    ]
    # Collect all splits so an unexpected extra fold fails the length check.
    splits = list(foo.split(X))
    assert_equal(len(splits), len(expected))
    for (train, test), (exp_train, exp_test) in zip(splits, expected):
        assert_array_equal(train, exp_train)
        assert_array_equal(test, exp_test)
def test_default_input(self):
    """Default arguments give one-sample test folds and a growing training set.

    Also verifies that ``split`` yields exactly ``get_n_splits`` folds.
    """
    X = range(3)
    foo = GapRollForward()
    assert_equal(foo.get_n_splits(X), 3)
    expected = [
        ([], [0]),
        ([0], [1]),
        ([0, 1], [2]),
    ]
    # Collect all splits so an unexpected extra fold fails the length check.
    splits = list(foo.split(X))
    assert_equal(len(splits), len(expected))
    for (train, test), (exp_train, exp_test) in zip(splits, expected):
        assert_array_equal(train, exp_train)
        assert_array_equal(test, exp_test)
def test_invalid_input(self):
    """Requesting the first split must raise when no valid split exists.

    Here ``min_train_size`` (3) plus ``gap_size`` (7) equals ``len(X)`` (10),
    so the expected outcome is a ``ValueError``.
    """
    X = range(10)
    expected_msg = "No valid splits for the input arguments."
    with pytest.raises(ValueError, match=expected_msg):
        # Construction and the split call stay inside the context manager in
        # case validation happens eagerly rather than on the first next().
        cv = GapRollForward(min_train_size=3, gap_size=7)
        next(cv.split(X))