def test_too_many_folds(self, time_series):
    """Asking for more folds than there are records raises ``ValueError``."""
    with pytest.raises(ValueError):
        # Drain the result completely; if time_series_split is a generator,
        # its validation presumably only runs once iteration starts.
        list(time_series_split(time_series,
                               n_splits=(len(time_series) + 1),
                               split_on='index'))
def test_split_on_neither_time_nor_index(self, time_series):
    """An unrecognised ``split_on`` value (here ``'abc'``) raises ``ValueError``."""
    with pytest.raises(ValueError):
        # Drain the result so any lazily-deferred validation is triggered.
        list(time_series_split(time_series, n_splits=5, split_on='abc'))
def test_split_on_time_with_non_time_indexed_dataframe(
        self, non_time_index):
    """``split_on='time'`` on a frame whose index is not DateTime raises ``ValueError``."""
    with pytest.raises(ValueError):
        # Drain the result so any lazily-deferred validation is triggered.
        list(time_series_split(non_time_index, n_splits=5,
                               split_on='time'))
def test_split_on_time(self, time_series):
    """Time-based folds must have exactly the reference fold sizes.

    The sizes of the folds yielded by ``time_series_split(..., n_splits=5,
    split_on='time')`` are compared with the sizes yielded by
    ``self._correct_split_on_time_record_length``.
    """
    record_count = [
        len(fold_index)
        for fold_index in time_series_split(time_series, n_splits=5,
                                            split_on='time')
    ]
    correct_record_count = [
        len(fold_index)
        for fold_index in self._correct_split_on_time_record_length(
            time_series)
    ]
    # Compare whole lists: the previous ``all(zip(...))`` stopped at the
    # shorter list, so too few (or zero) folds passed silently.
    assert record_count == correct_record_count
def test_period_index(self, period_index):
    """Time-based splitting of a period-indexed series matches the reference.

    The sizes of the folds yielded by ``time_series_split(..., n_splits=5,
    split_on="time")`` are compared with the sizes yielded by
    ``self._correct_split_on_time_record_length``.
    """
    record_count = [
        len(fold_index)
        for fold_index in time_series_split(period_index, n_splits=5,
                                            split_on="time")
    ]
    correct_record_count = [
        len(fold_index)
        for fold_index in self._correct_split_on_time_record_length(
            period_index)
    ]
    # Compare whole lists: the previous ``all(zip(...))`` stopped at the
    # shorter list, so too few (or zero) folds passed silently.
    assert record_count == correct_record_count
def test_split_on_index(self, time_series):
    """Index-based folds grow in steps of ``len(series) // n_splits``.

    Every fold except the last must contain ``k * split_length`` records
    (k = 1, 2, ...), and the final fold must cover the whole series.
    """
    splits = 5
    split_length = len(time_series) // splits
    length_list = [
        len(fold_index)
        for fold_index in time_series_split(time_series, n_splits=splits,
                                            split_on='index')
    ]
    # Without this guard, an empty result surfaces as an IndexError on
    # ``length_list[-1]`` instead of a readable assertion failure.
    assert length_list, "time_series_split yielded no folds"
    expected_lengths = [split_length * k for k in range(1, len(length_list))]
    assert length_list[:-1] == expected_lengths
    assert length_list[-1] == len(time_series)