def train_val_test_split(
    data: np.ndarray,
    labels: np.ndarray,
    train_size: Union[List, float, int] = 0.8,
    val_size: float = 0.1,
    stratified: bool = True,
    seed: int = 0
) -> Tuple[np.ndarray, np.ndarray, np.ndarray,
           np.ndarray, np.ndarray, np.ndarray]:
    """
    Partition samples and their labels into train, validation and test sets.

    The training portion is selected first according to ``train_size``; the
    validation set is then carved out of that portion, and every sample that
    was not selected for training becomes part of the test set.

    Note that ``shuffle_arrays_together`` is called only for its side effect
    (its return value is ignored), so ``data`` and ``labels`` passed by the
    caller are expected to be reordered in place.

    :param data: Data with the [SAMPLES, ...] dimensions
    :param labels: Vector with the label of each sample
    :param train_size: Forwarded to ``_get_set_indices``.
        If float (between 0.0 and 1.0) and stratified = True, the fraction
        of each class to extract. If float and stratified = False, the
        fraction of the whole dataset, drawn regardless of class.
        If int and stratified = True, the number of samples to draw from
        each class. If int and stratified = False, the overall number of
        samples to draw regardless of class. Defaults to 0.8
    :param val_size: Should be between 0.0 and 1.0. Portion of the selected
        training samples that is moved to the validation set; it is passed
        to ``_get_set_indices`` without an explicit ``stratified`` argument,
        so that helper's default applies. Defaults to 0.1
    :param stratified: Whether the training selection should be stratified,
        defaults to True
    :param seed: Seed used for data shuffling
    :return: train_x, train_y, val_x, val_y, test_x, test_y
    :raises AssertionError: When wrong type is passed as train_size
        (presumably raised inside ``_get_set_indices``)
    """
    shuffle_arrays_together([data, labels], seed=seed)

    # Indices (into the shuffled arrays) of everything picked for training,
    # validation samples still included.
    train_val_idx = _get_set_indices(train_size, labels, stratified)

    # The helper returns positions relative to the training subset it was
    # given; translate them back into indices of the full arrays.
    val_positions = _get_set_indices(val_size, labels[train_val_idx])
    val_idx = train_val_idx[val_positions]

    # Test set: whatever was not selected for training at all.
    every_idx = np.arange(len(data))
    test_idx = np.setdiff1d(every_idx, train_val_idx)

    # Final training set: the training selection minus the validation samples.
    train_idx = np.setdiff1d(train_val_idx, val_idx)

    return (
        data[train_idx], labels[train_idx],
        data[val_idx], labels[val_idx],
        data[test_idx], labels[test_idx],
    )
def test_if_arrays_are_modified_in_place(self):
    """Shuffling must change the array object that was passed in."""
    shuffled = np.arange(10)
    utils.shuffle_arrays_together([shuffled])
    untouched = np.arange(10)
    # At least one element has to differ from the original ordering.
    assert np.any(np.not_equal(shuffled, untouched))
def test_if_throws_for_arrays_with_different_sizes(self):
    """Arrays of unequal length must be rejected with an AssertionError."""
    longer = np.arange(10)
    shorter = np.arange(5)
    mismatched = [longer, shorter]
    with pytest.raises(AssertionError):
        utils.shuffle_arrays_together(mismatched)
def test_if_works_for_different_seeds(self, arrays, seed):
    """
    For the supplied seed, every array must still match the first one
    element-wise after the joint shuffle.

    NOTE(review): ``arrays`` and ``seed`` are supplied from outside this
    function (presumably by pytest parametrization) - confirm the arrays
    start out identical.
    """
    utils.shuffle_arrays_together(arrays, seed)
    matches_first = [
        np.equal(other, arrays[0]) for other in arrays[1:]
    ]
    assert np.all(matches_first)
def test_if_shuffled_arrays_have_same_order(self, arrays):
    """
    After a joint shuffle with the default seed, each array must still be
    element-wise equal to the first array in the list.

    NOTE(review): ``arrays`` is supplied from outside this function
    (presumably a fixture or parametrization) - confirm the arrays start
    out identical.
    """
    utils.shuffle_arrays_together(arrays)
    comparisons = [
        np.equal(candidate, arrays[0]) for candidate in arrays[1:]
    ]
    assert np.all(comparisons)