def _three_way_split(splitter: KFold, X, y: Optional = None,
                     groups: Optional = None) -> Generator:
    """Yield rotating (train, val, test) index splits from a K-fold splitter.

    Adapted from ``BaseCrossValidator.split()``. Instead of a two-way
    train/test split, each yielded split uses:

    * the current fold as the test set;
    * the fold produced just before it as the validation set;
    * every remaining sample as the training set (K-2 / 1 / 1 folds).

    The pairing wraps around, so the final split tests on the first fold
    and validates on the last one.

    Args:
        splitter: Cross-validator whose private ``_iter_test_masks`` method
            supplies one test mask per fold.
        X: Samples. After ``indexable``, ``_num_samples(X)`` fixes the range
            of positions ``0..n_samples-1`` that the yielded arrays index.
        y: Optional targets; forwarded to ``indexable`` and the splitter.
        groups: Optional group labels; forwarded to ``indexable`` and the
            splitter.

    Yields:
        ``(train_index, val_index, test_index)`` tuples of integer positions,
        one tuple per mask produced by the splitter.
    """
    X, y, groups = indexable(X, y, groups)
    positions = np.arange(_num_samples(X))

    def _carve(val_mask, test_mask):
        # Training set = every sample held out by neither the val nor test fold.
        held_out = np.logical_or(val_mask, test_mask)
        return positions[~held_out], positions[val_mask], positions[test_mask]

    # NOTE(review): relies on the splitter's private ``_iter_test_masks``;
    # presumably tied to scikit-learn internals — confirm when upgrading.
    masks = splitter._iter_test_masks(X, y, groups)

    # Keep the first fold: it becomes the test fold of the wrap-around split.
    # If the splitter yields no masks, next() raises StopIteration inside this
    # generator, which Python 3.7+ surfaces as RuntimeError (PEP 479).
    wrap_mask = prev_mask = next(masks)
    for cur_mask in masks:
        # The previous fold's test mask serves as this split's validation mask.
        yield _carve(prev_mask, cur_mask)
        prev_mask = cur_mask

    # Close the cycle: validate on the last fold, test on the first.
    # NOTE(review): with a single mask, val and test are the same fold. If the
    # masks partition the samples, a 2-fold splitter leaves train empty.
    yield _carve(prev_mask, wrap_mask)