Beispiel #1
0
    def test_too_many_folds(self, time_series):
        """Asking for ``len(time_series) + 1`` splits must raise ``ValueError``."""
        excessive_splits = len(time_series) + 1

        with pytest.raises(ValueError):
            # Drain the splitter so the error surfaces even if it is only
            # raised once iteration starts.
            list(
                time_series_split(
                    time_series, n_splits=excessive_splits, split_on='index'
                )
            )
Beispiel #2
0
    def test_split_on_neither_time_nor_index(self, time_series):
        """``split_on='abc'`` is expected to be rejected with ``ValueError``."""
        with pytest.raises(ValueError):
            # Drain the splitter so the error surfaces even if it is only
            # raised once iteration starts.
            list(time_series_split(time_series, n_splits=5, split_on='abc'))
Beispiel #3
0
    def test_split_on_time_with_non_time_indexed_dataframe(
            self, non_time_index):
        """``split_on='time'`` with a non-DateTime index must raise.

        The ``non_time_index`` fixture supplies the input; a ``ValueError``
        is expected when it is split on time.
        """
        with pytest.raises(ValueError):
            # Drain the splitter so the error surfaces even if it is only
            # raised once iteration starts.
            list(time_series_split(non_time_index, n_splits=5, split_on='time'))
Beispiel #4
0
    def test_split_on_time(self, time_series):
        """Fold sizes from a time-based split must match the reference helper.

        Collects ``len`` of every fold yielded by ``time_series_split`` with
        ``n_splits=5`` and ``split_on='time'`` and compares them with the
        lengths produced by ``self._correct_split_on_time_record_length`` for
        the same input.
        """
        record_count = [
            len(fold_index)
            for fold_index in time_series_split(
                time_series, n_splits=5, split_on='time')
        ]
        correct_record_count = [
            len(fold_index)
            for fold_index in self._correct_split_on_time_record_length(
                time_series)
        ]

        # Compare whole lists rather than zipping them: ``zip`` stops at the
        # shorter input, which would let a missing or extra fold (or no folds
        # at all) go unnoticed.
        assert record_count == correct_record_count
    def test_period_index(self, period_index):
        """Time-based split of the ``period_index`` fixture matches the helper.

        Collects ``len`` of every fold yielded by ``time_series_split`` with
        ``n_splits=5`` and ``split_on="time"`` and compares them with the
        lengths produced by ``self._correct_split_on_time_record_length`` for
        the same input.
        """
        record_count = [
            len(fold_index)
            for fold_index in time_series_split(
                period_index, n_splits=5, split_on="time")
        ]
        correct_record_count = [
            len(fold_index)
            for fold_index in self._correct_split_on_time_record_length(
                period_index)
        ]

        # Compare whole lists rather than zipping them: ``zip`` stops at the
        # shorter input, which would let a missing or extra fold (or no folds
        # at all) go unnoticed.
        assert record_count == correct_record_count
Beispiel #6
0
    def test_split_on_index(self, time_series):
        """Index-based folds should grow by one split length per step.

        With ``n_splits=5`` every fold except the last is expected to hold a
        whole multiple of ``len(time_series) // 5`` records, and the last
        fold is expected to be as long as the entire series.
        """
        n_folds = 5
        step = len(time_series) // n_folds

        fold_sizes = [
            len(fold)
            for fold in time_series_split(
                time_series, n_splits=n_folds, split_on='index')
        ]

        # The k-th fold (1-based) should contain k * step records; the final
        # fold is checked separately against the full series length.
        for k, size in enumerate(fold_sizes[:-1], start=1):
            assert k * step == size

        full_length = len(time_series)
        assert full_length == fold_sizes[-1]