コード例 #1
0
ファイル: preprocessing.py プロジェクト: zlwdghh/hypernet
def train_val_test_split(
    data: np.ndarray,
    labels: np.ndarray,
    train_size: Union[List, float, int] = 0.8,
    val_size: float = 0.1,
    stratified: bool = True,
    seed: int = 0
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
           np.ndarray]:
    """
    Partition the samples into training, validation and test subsets.

    A training pool is selected first according to train_size. The validation
    subset is then carved out of that pool, and every sample that was not
    selected for the pool becomes part of the test subset.

    NOTE: the return value of shuffle_arrays_together is discarded here, so
    the shuffle is expected to happen in place - i.e. the data and labels
    arrays passed by the caller are presumably reordered by this call.

    :param data: Data with the [SAMPLES, ...] dimensions
    :param labels: Vector with the label of each sample
    :param train_size: Size of the training pool; it is forwarded, together
                       with stratified, to _get_set_indices:
                 float (between 0.0 and 1.0) and stratified = True: fraction
                    of each class to be extracted.
                 float and stratified = False: fraction of the whole dataset,
                    drawn randomly regardless of class.
                 int and stratified = True: number of samples to be drawn
                    from each class.
                 int and stratified = False: overall number of samples, drawn
                    randomly regardless of class.
                 The annotation also admits a list; how it is interpreted is
                 decided by _get_set_indices (not visible here).
                 Defaults to 0.8
    :param val_size: Should be between 0.0 and 1.0. Fraction of each class of
                     the training pool that is moved to the validation set,
                     defaults to 0.1
    :param stratified: Whether the training pool should be drawn in a
                       stratified manner, defaults to True
    :param seed: Seed used for data shuffling
    :return: train_x, train_y, val_x, val_y, test_x, test_y
    :raises AssertionError: When wrong type is passed as train_size
    """
    shuffle_arrays_together([data, labels], seed=seed)

    # Indices (into data / labels) of the whole training pool.
    pool = _get_set_indices(train_size, labels, stratified)

    # The helper is given only the pool's labels, so its result indexes into
    # the pool; map it back to positions in the full dataset. No third
    # argument is passed, so the helper's own default applies here.
    val_in_pool = _get_set_indices(val_size, labels[pool])
    val_idx = pool[val_in_pool]

    # np.setdiff1d returns sorted, unique indices.
    train_idx = np.setdiff1d(pool, val_idx)
    test_idx = np.setdiff1d(np.arange(len(data)), pool)

    return (
        data[train_idx], labels[train_idx],
        data[val_idx], labels[val_idx],
        data[test_idx], labels[test_idx],
    )
コード例 #2
0
 def test_if_arrays_are_modified_in_place(self):
     """Check that shuffling reorders the very array object passed in."""
     shuffled = np.arange(10)
     utils.shuffle_arrays_together([shuffled])
     ascending = np.arange(10)
     # At least one position has to differ from the initial ascending order.
     assert np.any(np.not_equal(shuffled, ascending))
コード例 #3
0
 def test_if_throws_for_arrays_with_different_sizes(self):
     """Check that arrays of unequal length raise an AssertionError."""
     # Lengths 10 and 5 cannot be shuffled with one shared permutation.
     mismatched = [np.arange(10), np.arange(5)]
     with pytest.raises(AssertionError):
         utils.shuffle_arrays_together(mismatched)
コード例 #4
0
 def test_if_works_for_different_seeds(self, arrays, seed):
     """
     Check that, for the given seed, all arrays still match after shuffling.

     arrays and seed are supplied from outside this method (presumably by
     pytest fixtures / parametrization - not visible here).
     """
     utils.shuffle_arrays_together(arrays, seed)
     # Element-wise comparison of every later array against the first one.
     matches_first = [np.equal(other, arrays[0]) for other in arrays[1:]]
     assert np.all(matches_first)
コード例 #5
0
 def test_if_shuffled_arrays_have_same_order(self, arrays):
     """
     Check that all arrays share the same element order after shuffling.

     arrays is supplied from outside this method (presumably a pytest
     fixture / parametrization - not visible here).
     """
     utils.shuffle_arrays_together(arrays)
     # Compare each remaining array element-wise with the first array.
     comparisons = [np.equal(candidate, arrays[0]) for candidate in arrays[1:]]
     assert np.all(comparisons)