Ejemplo n.º 1
0
def test_load_data(h5file):
    """Check that a reader built from the JSON dataset config matches a hand-built one.

    The hand-built reader uses the ``h5file`` argument (presumably a pytest
    fixture), a batch size of 8, one ``PlusOnePreprocessor`` and folds
    0-2 / 3 / 4-5 for train / val / test. The config-built reader must be an
    ``HDF5Reader`` whose first preprocessor is a ``PlusOnePreprocessor``, and
    its three generators must yield the same data as the hand-built reader's.
    """
    @custom_preprocessor
    class PlusOnePreprocessor(BasePreprocessor):
        def transform(self, x, y):
            # Shift both the inputs and the targets by one.
            return x + 1, y + 1

    config_text = read_file('tests/json/dataset_config.json')
    actual_reader = load_data(load_json_config(config_text))

    expected_reader = HDF5Reader(
        filename=h5file,
        batch_size=8,
        preprocessors=PlusOnePreprocessor(),
        x_name='input',
        y_name='target',
        train_folds=[0, 1, 2],
        val_folds=[3],
        test_folds=[4, 5]
    )

    assert isinstance(actual_reader, HDF5Reader)
    assert isinstance(actual_reader.preprocessors[0], PlusOnePreprocessor)

    split_names = ('train_generator', 'val_generator', 'test_generator')
    # Fetch every generator of the config-built reader first, then those of
    # the hand-built reader, and only then compare them split by split.
    actual_generators = [getattr(actual_reader, name) for name in split_names]
    expected_generators = [getattr(expected_reader, name) for name in split_names]

    for actual_gen, expected_gen in zip(actual_generators, expected_generators):
        check_equal_data_generator(actual_gen, expected_gen)
Ejemplo n.º 2
0
def test_hdf5_dr_generator():
    """Check that the training generator yields the expected batches in order.

    Folds 0-3 of ``VALID_H5_FILE`` are used for training. The expected data
    below implies 100 samples per fold:
      * inputs are ``0 .. 9999`` reshaped to ``(400, 5, 5)``;
      * targets repeat ``[1, 2, 3, 4, 5]``.
    Batches are expected not to cross fold boundaries, so the last batch of
    each fold is shorter than ``batch_size``.
    """
    batch_size = 8
    samples_per_fold = 100
    dr = HDF5Reader(VALID_H5_FILE,
                    batch_size=batch_size,
                    x_name='input',
                    y_name='target',
                    train_folds=[0, 1, 2, 3],
                    val_folds=[4],
                    test_folds=[5, 6])

    expected_train_input = np.reshape(np.arange(400 * 25), (400, 5, 5))
    expected_train_target = np.array([1, 2, 3, 4, 5] * 80)

    total_batch = dr.train_generator.total_batch
    data_gen = dr.train_generator.generate()
    end = 0

    for i, (input_data, target) in enumerate(data_gen):
        # Stop after one full pass; generate() presumably keeps yielding
        # batches beyond that.
        if i >= total_batch:
            break

        start = end
        end = start + batch_size

        # A batch that would spill into the next fold is truncated at the
        # fold boundary.
        if end % samples_per_fold < batch_size:
            end -= end % samples_per_fold

        # assert_array_equal also checks shapes and reports the mismatching
        # elements, so no debug printing is needed.
        np.testing.assert_array_equal(target, expected_train_target[start:end])
        np.testing.assert_array_equal(input_data, expected_train_input[start:end])

    # Guard against a generator that stops early: one full pass must have
    # covered every expected training sample.
    assert end == len(expected_train_target)
Ejemplo n.º 3
0
def test_hdf5_dr_constructor_h5file_no_groups():
    """Accessing ``train_generator`` on ``INVALID_H5_FILE_NO_GROUPS`` raises RuntimeError.

    Construction itself is expected to succeed; the error surfaces lazily.
    """
    reader_kwargs = dict(batch_size=8, x_name='input', y_name='target')
    reader = HDF5Reader(INVALID_H5_FILE_NO_GROUPS, **reader_kwargs)

    with pytest.raises(RuntimeError):
        getattr(reader, 'train_generator')
Ejemplo n.º 4
0
def test_hdf5_dr_constructor_invalid_yname():
    """An unknown ``y_name`` ('y') must surface as RuntimeError.

    Both the construction and the first ``train_generator`` access are inside
    the ``pytest.raises`` block, so the error may come from either step.
    """
    with pytest.raises(RuntimeError):
        reader = HDF5Reader(VALID_H5_FILE, batch_size=8,
                            x_name='input', y_name='y')
        _ = reader.train_generator
Ejemplo n.º 5
0
def test_hdf5_dr_total_batch():
    """Each split reports ``ceil(100 / 8)`` batches for every fold it contains.

    The constant 100 is the assumed number of samples per fold in
    ``VALID_H5_FILE``.
    """
    reader = HDF5Reader(VALID_H5_FILE, batch_size=8,
                        x_name='input', y_name='target',
                        train_folds=[0, 1, 2, 3], val_folds=[4],
                        test_folds=[5, 6])

    batches_per_fold = np.ceil(100 / 8)
    # (generator attribute, number of folds in that split), checked in the
    # order train, val, test.
    split_sizes = (('train_generator', 4),
                   ('val_generator', 1),
                   ('test_generator', 2))
    for attr, n_folds in split_sizes:
        assert getattr(reader, attr).total_batch == int(batches_per_fold * n_folds)
Ejemplo n.º 6
0
def test_hdf5_dr_constructor_no_prefix():
    """With ``fold_prefix=None`` and no explicit folds, folds are plain integers.

    The expected split is train ``[0]``, val ``[1]``, test ``[2]``. All three
    generators are then built to make sure none of them raises.
    """
    reader = HDF5Reader(VALID_H5_FILE_M, batch_size=8,
                        x_name='input', y_name='target',
                        fold_prefix=None)

    expected_folds = (('train_folds', [0]),
                      ('val_folds', [1]),
                      ('test_folds', [2]))
    for attr, folds in expected_folds:
        assert getattr(reader, attr) == folds

    # Only the side effect of building each generator matters here.
    for attr in ('train_generator', 'test_generator', 'val_generator'):
        getattr(reader, attr)
Ejemplo n.º 7
0
def test_hdf5_dr_constructor():
    """With ``fold_prefix='fold'`` and no explicit folds, fold names carry the prefix.

    The expected split is train ``['fold_0']``, val ``['fold_1']``, test
    ``['fold_2']``. All three generators are then built to make sure none of
    them raises.
    """
    reader = HDF5Reader(VALID_H5_FILE, batch_size=8,
                        x_name='input', y_name='target',
                        fold_prefix='fold')

    expected_folds = (('train_folds', ['fold_0']),
                      ('val_folds', ['fold_1']),
                      ('test_folds', ['fold_2']))
    for attr, folds in expected_folds:
        assert getattr(reader, attr) == folds

    # Only the side effect of building each generator matters here.
    for attr in ('train_generator', 'test_generator', 'val_generator'):
        getattr(reader, attr)
Ejemplo n.º 8
0
def test_hdf5_dr_original_val():
    """``original_val`` exposes exactly the ``input`` and ``target`` entries.

    With folds 4 and 5 as validation, each entry is expected to have 200
    rows (i.e. 100 per fold).
    """
    reader = HDF5Reader(VALID_H5_FILE, batch_size=8,
                        x_name='input', y_name='target',
                        train_folds=[0, 1, 2, 3], val_folds=[4, 5],
                        test_folds=[6, 7, 8])

    val_data = reader.original_val
    assert len(val_data.keys()) == 2
    for name in ('input', 'target'):
        assert name in val_data
        assert val_data[name].shape[0] == 200
Ejemplo n.º 9
0
def test_hdf5_dr_constructor_invalid_preprocessor():
    """A non-preprocessor entry (the int ``1``) makes every generator raise ValueError.

    Construction is expected to succeed; each generator access is checked
    separately, in the order train, val, test.
    """
    reader = HDF5Reader(VALID_H5_FILE, batch_size=8,
                        preprocessors=[1],
                        x_name='input', y_name='target')

    for attr in ('train_generator', 'val_generator', 'test_generator'):
        with pytest.raises(ValueError):
            getattr(reader, attr)
Ejemplo n.º 10
0
def test_hdf5_dr_original_val_more_column():
    """``original_val`` includes every dataset of the val fold, not just x/y.

    Validation fold 8 of ``INVALID_H5_FILE_GROUPS`` is expected to hold
    ``x``, ``y`` and ``z``, each with 100 rows.
    """
    reader = HDF5Reader(INVALID_H5_FILE_GROUPS, batch_size=8,
                        x_name='x', y_name='y',
                        train_folds=[6, 7], val_folds=[8],
                        test_folds=[9])

    val_data = reader.original_val
    assert len(val_data.keys()) == 3
    for name in ('x', 'y', 'z'):
        assert name in val_data
        assert val_data[name].shape[0] == 100
Ejemplo n.º 11
0
def test_hdf5_dr_constructor_invalid_h5_structure():
    """Training folds with mismatched dataset layouts make ``train_generator`` raise.

    Datasets in ``INVALID_H5_FILE_GROUPS`` by fold index:
      * 0 - 2: input, target
      * 3 - 5: x, y
      * 6 - 10: x, y, z

    The training folds (3-6) mix the two-dataset and three-dataset layouts,
    so RuntimeError is expected. The test (8-9) and val (7) folds share a
    single layout and are expected to build without error.
    """
    reader = HDF5Reader(INVALID_H5_FILE_GROUPS, batch_size=8,
                        x_name='x', y_name='y',
                        train_folds=[3, 4, 5, 6], val_folds=[7],
                        test_folds=[8, 9])

    with pytest.raises(RuntimeError):
        getattr(reader, 'train_generator')

    for attr in ('test_generator', 'val_generator'):
        getattr(reader, attr)
Ejemplo n.º 12
0
def test_hdf5_dr_constructor_invalid_datasetname():
    """Folds lacking the requested dataset names make their generator raise.

    Datasets in ``INVALID_H5_FILE_GROUPS`` by fold index:
      * 0 - 2: input, target
      * 3 - 5: x, y
      * 6 - 10: x, y, z

    Only the training folds (0-2) contain ``input``/``target``. The val
    (3-5) and test (6-8) generators are therefore expected to raise
    RuntimeError.
    """
    reader = HDF5Reader(INVALID_H5_FILE_GROUPS, batch_size=8,
                        x_name='input', y_name='target',
                        train_folds=[0, 1, 2], val_folds=[3, 4, 5],
                        test_folds=[6, 7, 8])

    getattr(reader, 'train_generator')

    for attr in ('val_generator', 'test_generator'):
        with pytest.raises(RuntimeError):
            getattr(reader, attr)
Ejemplo n.º 13
0
def test_hdf5_dr_generator_multipreprocessors():
    """Check that chained preprocessors are all applied to the training batches.

    Two ``PlusOnePreprocessor`` instances are chained, so every input and
    target value is expected to be shifted by +2 relative to the raw data:
      * inputs ``2 .. 10001`` reshaped to ``(400, 5, 5)``;
      * targets repeating ``[3, 4, 5, 6, 7]``.
    The expected data implies 100 samples per fold, and batches are expected
    not to cross fold boundaries.
    """
    batch_size = 8
    samples_per_fold = 100

    @custom_preprocessor
    class PlusOnePreprocessor(BasePreprocessor):
        # transform is an instance method, so the reader is given instances.
        # The previous `transform(x, y)` lacked `self` and only worked if it
        # was invoked on the class object itself.
        def transform(self, x, y):
            return x + 1, y + 1

    dr = HDF5Reader(VALID_H5_FILE,
                    batch_size=batch_size,
                    x_name='input',
                    y_name='target',
                    train_folds=[0, 1, 2, 3],
                    val_folds=[4],
                    test_folds=[5, 6],
                    preprocessors=[PlusOnePreprocessor(), PlusOnePreprocessor()])

    expected_train_input = np.reshape(np.arange(2, 400 * 25 + 2), (400, 5, 5))
    expected_train_target = np.array([3, 4, 5, 6, 7] * 80)

    total_batch = dr.train_generator.total_batch
    data_gen = dr.train_generator.generate()
    end = 0

    for i, (input_data, target) in enumerate(data_gen):
        # Stop after one full pass; generate() presumably keeps yielding
        # batches beyond that.
        if i >= total_batch:
            break

        start = end
        end = start + batch_size

        # A batch that would spill into the next fold is truncated at the
        # fold boundary.
        if end % samples_per_fold < batch_size:
            end -= end % samples_per_fold

        # assert_array_equal also checks shapes and reports the mismatching
        # elements, so no debug printing is needed.
        np.testing.assert_array_equal(target, expected_train_target[start:end])
        np.testing.assert_array_equal(input_data, expected_train_input[start:end])

    # Guard against a generator that stops early: one full pass must have
    # covered every expected training sample.
    assert end == len(expected_train_target)
Ejemplo n.º 14
0
def test_hdf5_dr_constructor_invalid_xname():
    """An unknown ``x_name`` ('x') makes ``train_generator`` raise RuntimeError.

    Construction itself is expected to succeed; the error surfaces on first
    generator access.
    """
    reader = HDF5Reader(VALID_H5_FILE,
                        batch_size=8,
                        x_name='x',
                        y_name='target')

    with pytest.raises(RuntimeError):
        _ = reader.train_generator