Ejemplo n.º 1
0
def test_get_batch_match_frame_len():
    """Batches from ``flow_from_dataframe`` together account for every grid row.

    The generator is consumed to exhaustion with ``batch_size=3``; the sizes of
    the leading (batch) dimension of every yielded batch must sum to ``len(df)``.
    """
    size = (64, 64)
    sdg = SpatialDataGenerator()
    sdg.source = 'data/small.tif'
    sdg.width, sdg.height = size
    df = sdg.regular_grid(*size)
    # Guard against a vacuous pass: an empty grid would give count == 0 == len(df).
    assert len(df) > 0

    gen = sdg.flow_from_dataframe(df, batch_size=3)
    # Generator expression: sum() does not need a materialised list.
    count = sum(batch.shape[0] for batch in gen)
    assert count == len(df)
Ejemplo n.º 2
0
def test_sample_size():
    """The first batch is 4-D, holds at most ``batch_size`` samples, and its
    last two axes match the configured tile size.

    NOTE(review): the last two axes are compared against ``tile_size`` in
    (width, height) order; with a square tile the order does not matter.
    """
    tile_size = (64, 64)
    data_gen = SpatialDataGenerator()
    data_gen.source = 'data/small.tif'
    data_gen.width, data_gen.height = tile_size
    grid = data_gen.regular_grid(*tile_size)
    first_batch = next(data_gen.flow_from_dataframe(grid))

    assert len(first_batch.shape) == 4
    assert first_batch.shape[0] == min(data_gen.batch_size, len(grid))
    assert first_batch.shape[-2:] == tile_size
Ejemplo n.º 3
0
def test_preprocess_add_array():
    """A preprocess callback that stacks an extra array onto each sample shows
    up as a second channel axis of length 2 in the batch."""
    def _stack_with_tenth(arr):
        # Pair the sample with a copy scaled by 1/10 along a new leading axis.
        return np.stack((arr, arr / 10))

    tile_size = (64, 64)
    data_gen = SpatialDataGenerator()
    data_gen.source = 'data/small.tif'
    data_gen.indexes = 1
    data_gen.width, data_gen.height = tile_size
    grid = data_gen.regular_grid(*tile_size)

    data_gen.add_preprocess_callback('pre', _stack_with_tenth)
    first_batch = next(data_gen.flow_from_dataframe(grid))
    batch_shape = first_batch.shape
    assert len(batch_shape) == 4
    assert batch_shape[0] == min(data_gen.batch_size, len(grid))
    # Two channels: the original array and its 1/10-scaled copy.
    assert batch_shape[1] == 2
    assert batch_shape[-2:] == tile_size
Ejemplo n.º 4
0
def test_preprocess_modify_array():
    """A normalising preprocess callback, given the grid-wide maximum as an
    extra argument, keeps every value of the resulting batch at or below 1.0."""
    def _normalize(arr, maxval):
        # Scale by the maximum supplied when the callback is registered.
        return arr / maxval

    tile_size = (64, 64)
    data_gen = SpatialDataGenerator()
    data_gen.source = 'data/small.tif'
    data_gen.indexes = 1
    data_gen.width, data_gen.height = tile_size
    grid = data_gen.regular_grid(*tile_size)

    # One-sample batches give one maximum per grid row. This runs before the
    # callback is registered, presumably so the maxima come from raw data.
    tile_maxima = []
    for tile in data_gen.flow_from_dataframe(grid, batch_size=1):
        tile_maxima.append(tile.max())
    grid['max'] = tile_maxima

    data_gen.add_preprocess_callback('normalize', _normalize, grid['max'].max())
    first_batch = next(data_gen.flow_from_dataframe(grid))
    # With a single band index the test expects a 3-D batch.
    assert len(first_batch.shape) == 3
    assert first_batch.shape[0] == min(data_gen.batch_size, len(grid))
    assert first_batch.shape[-2:] == tile_size
    assert first_batch.max() <= 1.0
Ejemplo n.º 5
0
def test_random_grid():
    """``random_grid(n)`` yields a frame with exactly ``n`` rows."""
    n_samples = 100
    data_gen = SpatialDataGenerator()
    data_gen.source = 'data/small.tif'
    data_gen.width = 64
    data_gen.height = 64
    grid = data_gen.random_grid(n_samples)
    assert len(grid) == n_samples
Ejemplo n.º 6
0
def test_regular_grid():
    """``regular_grid()`` called without arguments yields a non-empty frame."""
    data_gen = SpatialDataGenerator()
    # Tile dimensions are set before the source here (the reverse of the other
    # tests); the order is kept in case the source setter depends on it.
    data_gen.width, data_gen.height = 64, 64
    data_gen.source = 'data/small.tif'
    grid = data_gen.regular_grid()
    assert len(grid) > 0