コード例 #1
0
def test_data_range():
    """Check that ``set_data_range`` restricts and re-bases item indices.

    After ``set_data_range([2, 4])`` the dataset must report two items. New
    index 0 must read as all 2s and index 1 as all 3s. This suggests the range
    is half-open and the test file stores item k filled with k (TODO confirm
    against ``make_test_h5``).

    The check runs first without and then with ``preload()``. In the preload
    case, ``preload()`` is called before the range is set, as in the original
    test.

    Cleanup lives in ``finally`` blocks so that a failing assertion still
    closes the dataset and removes ``tmp.h5``. Otherwise the leaked handle or
    file could break later tests.
    """
    make_test_h5()
    try:
        for preload in (False, True):
            ds = loaders.H5Dataset('tmp.h5')
            try:
                if preload:
                    ds.preload()
                ds.set_data_range([2, 4])
                assert len(ds) == 2
                assert np.all(ds[0] == 2 * np.ones([5, 5]))
                assert np.all(ds[1] == 3 * np.ones([5, 5]))
            finally:
                ds.close()
    finally:
        rm_test_h5()
コード例 #2
0
def test_clip():
    """Check that the ``clip`` option bounds the values the dataset returns.

    With ``clip=(1, 5)``, item 0 must read back as all ones. Item 0 is
    presumably stored as zeros, so this shows values below the lower bound
    are raised to it.

    Cleanup runs in ``finally`` so a failing assertion does not leak the open
    dataset or ``tmp.h5`` into later tests.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5', clip=(1, 5))
        try:
            assert np.all(ds[0] == np.ones([5, 5]))
        finally:
            ds.close()
    finally:
        rm_test_h5()
コード例 #3
0
def test_h5dataset_basics():
    """Check the dataset's length, per-item shape and the contents of item 0.

    The test file must expose 4 items of shape ``(5, 5)``, and item 0 must be
    all zeros.

    Cleanup runs in ``finally`` so a failing assertion does not leak the open
    dataset or ``tmp.h5`` into later tests.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5')
        try:
            item = ds[0]
            assert len(ds) == 4
            assert ds.shape == (5, 5)
            assert np.all(item == np.zeros([5, 5]))
        finally:
            ds.close()
    finally:
        rm_test_h5()
コード例 #4
0
def test_preload():
    """Check that ``preload()`` materialises the data as a NumPy array.

    After preloading, the private ``_data`` attribute must be an ndarray. The
    dataset must still report 4 items of shape ``(5, 5)``, with item 0 all
    zeros.

    Cleanup runs in ``finally`` so a failing assertion does not leak the open
    dataset or ``tmp.h5`` into later tests.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5')
        try:
            ds.preload()
            # isinstance is the idiomatic type check (also accepts ndarray subclasses).
            assert isinstance(ds._data, np.ndarray)
            item = ds[0]
            assert len(ds) == 4
            assert ds.shape == (5, 5)
            assert np.all(item == np.zeros([5, 5]))
        finally:
            ds.close()
    finally:
        rm_test_h5()
コード例 #5
0
def test_PDDL():
    """Check the batches produced by ``loaders.PreloadingDDL``.

    The loader is built with positional args ``(ds, 0, 2)``, presumably rank
    0 of 2 replicas (TODO confirm), and ``batch_size=2``. Each yielded batch
    must have shape ``(2, 5, 5)``. In batch i, the first sample must be all i
    and the second all i + 1.

    The batch count is asserted because an empty loader would otherwise skip
    every check in the loop and pass vacuously. Cleanup runs in ``finally`` so
    a failure does not leak the open dataset or ``tmp.h5``.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5')
        try:
            ddl = loaders.PreloadingDDL(ds, 0, 2, batch_size=2)
            n_batches = 0
            for i, b in enumerate(ddl):
                assert np.all(b[0].numpy() == i)
                assert np.all(b[1].numpy() == i + 1)
                assert b.shape == (2, 5, 5)
                n_batches += 1
            assert n_batches > 0, "PreloadingDDL yielded no batches"
        finally:
            ds.close()
    finally:
        rm_test_h5()
コード例 #6
0
def test_epoch_style_use():
    """Check that a ``DistributedDataLoader`` can be re-iterated across epochs.

    The same loader is iterated five times. Every pass must yield batches of
    shape ``(2, 5, 5)``, where batch i holds all-i then all-(i + 1) samples.

    Batches are counted per epoch because a one-shot iterator would yield
    nothing after the first epoch. That would silently skip every assertion
    in epochs 2-5, which is exactly the failure this test exists to catch.
    Cleanup runs in ``finally`` so a failure does not leak the open dataset
    or ``tmp.h5``.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5')
        try:
            ddl = loaders.DistributedDataLoader(ds, 0, 2, batch_size=2)
            for epoch in range(5):
                n_batches = 0
                for i, b in enumerate(ddl):
                    assert np.all(b[0].numpy() == i)
                    assert np.all(b[1].numpy() == i + 1)
                    assert b.shape == (2, 5, 5)
                    n_batches += 1
                assert n_batches > 0, f"epoch {epoch} yielded no batches"
        finally:
            ds.close()
    finally:
        rm_test_h5()
コード例 #7
0
def test_integration():
    """Run ``DistributedDataLoader`` over a plain and a preloaded dataset.

    Both passes build the loader with ``(ds, 0, 2, batch_size=2)``. They
    expect batches of shape ``(2, 5, 5)`` where batch i holds all-i then
    all-(i + 1) samples. The only difference is whether ``preload()`` was
    called first.

    Each pass must yield at least one batch so the loop's assertions cannot
    pass vacuously. Cleanup runs in ``finally`` so a failure does not leak
    the open dataset or ``tmp.h5``.
    """
    make_test_h5()
    try:
        for preload in (False, True):
            ds = loaders.H5Dataset('tmp.h5')
            try:
                if preload:
                    ds.preload()
                ddl = loaders.DistributedDataLoader(ds, 0, 2, batch_size=2)
                n_batches = 0
                for i, b in enumerate(ddl):
                    assert np.all(b[0].numpy() == i)
                    assert np.all(b[1].numpy() == i + 1)
                    assert b.shape == (2, 5, 5)
                    n_batches += 1
                assert n_batches > 0, f"preload={preload}: loader yielded no batches"
            finally:
                ds.close()
    finally:
        rm_test_h5()
コード例 #8
0
def test_slice():
    """Check slice indexing on a dataset with a restricted data range.

    After ``set_data_range([1, 5])`` the dataset must report four items, and
    ``ds[0:2]`` must return the all-1s and all-2s items in order. The check
    runs without and then with ``preload()``. In the preload case,
    ``preload()`` is called before the range is set.

    Cleanup runs in ``finally`` blocks so a failing assertion does not leak
    the open dataset or ``tmp.h5`` into later tests.
    """
    make_test_h5()
    try:
        for preload in (False, True):
            ds = loaders.H5Dataset('tmp.h5')
            try:
                if preload:
                    ds.preload()
                ds.set_data_range([1, 5])
                assert len(ds) == 4
                item = ds[0:2]
                assert np.all(item[0] == 1 * np.ones([5, 5]))
                assert np.all(item[1] == 2 * np.ones([5, 5]))
            finally:
                ds.close()
    finally:
        rm_test_h5()