def test_data_range():
    """Check that ``set_data_range`` restricts the dataset to a row window.

    With the range ``[2, 4]`` the dataset should report two items, and those
    items should equal ``2 * ones`` and ``3 * ones`` (5x5). Those values are
    presumably original rows 2 and 3 of the fixture written by
    ``make_test_h5``. The check runs first without and then with
    ``preload()``, calling ``preload()`` before ``set_data_range`` as in the
    original test.

    The dataset is closed and the temporary file removed even if an
    assertion fails, so a failure does not leak state into later tests.
    """
    make_test_h5()
    try:
        for preload in (False, True):
            ds = loaders.H5Dataset('tmp.h5')
            try:
                if preload:
                    ds.preload()
                ds.set_data_range([2, 4])
                assert len(ds) == 2
                assert np.all(ds[0] == 2 * np.ones([5, 5]))
                assert np.all(ds[1] == 3 * np.ones([5, 5]))
            finally:
                # Close before removing the file; an open handle can block deletion.
                ds.close()
    finally:
        rm_test_h5()
def test_clip():
    """Check that the ``clip`` option bounds the values returned by indexing.

    The basics test below expects item 0 of the fixture to be all zeros.
    With ``clip=(1, 5)`` that item is expected to come back as all ones,
    i.e. the loader presumably clamps values into ``[1, 5]``.

    Cleanup runs even if the assertion fails.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5', clip=(1, 5))
        try:
            assert np.all(ds[0] == np.ones([5, 5]))
        finally:
            ds.close()
    finally:
        rm_test_h5()
def test_h5dataset_basics():
    """Check length, per-item shape and first-item contents of ``H5Dataset``.

    The fixture is expected to hold four 5x5 items, the first of which is
    all zeros. Cleanup runs even if an assertion fails.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5')
        try:
            item = ds[0]
            assert len(ds) == 4
            assert ds.shape == (5, 5)
            assert np.all(item == np.zeros([5, 5]))
        finally:
            ds.close()
    finally:
        rm_test_h5()
def test_preload():
    """Check that ``preload()`` materialises the data as a NumPy array.

    After preloading, the dataset should behave as in the non-preloaded
    basics test: four 5x5 items, the first of which is all zeros.
    Cleanup runs even if an assertion fails.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5')
        try:
            ds.preload()
            # isinstance rather than exact type equality: an ndarray subclass is fine.
            assert isinstance(ds._data, np.ndarray)
            item = ds[0]
            assert len(ds) == 4
            assert ds.shape == (5, 5)
            assert np.all(item == np.zeros([5, 5]))
        finally:
            ds.close()
    finally:
        rm_test_h5()
def test_PDDL():
    """Check the batches produced by ``PreloadingDDL``.

    The loader is built with positional arguments ``0, 2`` (presumably
    rank 0 of 2 workers; confirm against ``loaders.PreloadingDDL``) and
    ``batch_size=2``. Batch ``i`` is expected to have shape ``(2, 5, 5)``,
    with element 0 filled with ``i`` and element 1 filled with ``i + 1``.

    The test also requires that at least one batch is produced. Without that
    check, an empty loader would skip every per-batch assertion and pass.
    Cleanup runs even if an assertion fails.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5')
        try:
            ddl = loaders.PreloadingDDL(ds, 0, 2, batch_size=2)
            n_batches = 0
            for i, b in enumerate(ddl):
                assert np.all(b[0].numpy() == i)
                assert np.all(b[1].numpy() == i + 1)
                assert b.shape == (2, 5, 5)
                n_batches += 1
            assert n_batches > 0, "PreloadingDDL yielded no batches"
        finally:
            ds.close()
    finally:
        rm_test_h5()
def test_epoch_style_use():
    """Check that a ``DistributedDataLoader`` can be iterated over several epochs.

    Every one of the 5 epochs must yield the same sequence of batches as a
    single pass: batch ``i`` has shape ``(2, 5, 5)``, element 0 equal to
    ``i`` and element 1 equal to ``i + 1``.

    Each epoch must produce the same, non-zero number of batches. If the
    loader were a one-shot iterator, later epochs would yield nothing. The
    per-batch checks alone would then still pass, defeating the purpose of
    this test. Cleanup runs even if an assertion fails.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5')
        try:
            ddl = loaders.DistributedDataLoader(ds, 0, 2, batch_size=2)
            batches_per_epoch = []
            for _ in range(5):
                n_batches = 0
                for i, b in enumerate(ddl):
                    assert np.all(b[0].numpy() == i)
                    assert np.all(b[1].numpy() == i + 1)
                    assert b.shape == (2, 5, 5)
                    n_batches += 1
                batches_per_epoch.append(n_batches)
            assert batches_per_epoch[0] > 0, "loader yielded no batches"
            assert len(set(batches_per_epoch)) == 1, (
                f"batch count changed across epochs: {batches_per_epoch}"
            )
        finally:
            ds.close()
    finally:
        rm_test_h5()
def test_integration():
    """Check ``DistributedDataLoader`` over both lazy and preloaded datasets.

    For each mode (non-preloaded first, then preloaded), batch ``i`` must
    have shape ``(2, 5, 5)``, with element 0 equal to ``i`` and element 1
    equal to ``i + 1``. At least one batch must be produced, so an empty
    loader cannot pass without running the per-batch checks. Each dataset is
    closed, and the temporary file is removed, even if an assertion fails.
    """
    make_test_h5()
    try:
        for preload in (False, True):
            ds = loaders.H5Dataset('tmp.h5')
            try:
                if preload:
                    ds.preload()
                ddl = loaders.DistributedDataLoader(ds, 0, 2, batch_size=2)
                n_batches = 0
                for i, b in enumerate(ddl):
                    assert np.all(b[0].numpy() == i)
                    assert np.all(b[1].numpy() == i + 1)
                    assert b.shape == (2, 5, 5)
                    n_batches += 1
                assert n_batches > 0, (
                    f"DistributedDataLoader yielded no batches (preload={preload})"
                )
            finally:
                ds.close()
    finally:
        rm_test_h5()
def test_slice():
    """Check slice indexing on a dataset restricted with ``set_data_range``.

    With the range ``[1, 5]`` the dataset reports four items. The slice
    ``ds[0:2]`` should return items equal to ``1 * ones`` and ``2 * ones``
    (5x5), which are presumably original rows 1 and 2. The check runs
    without and then with ``preload()``, calling ``preload()`` before
    ``set_data_range`` as in the original test. Cleanup runs even if an
    assertion fails.
    """
    make_test_h5()
    try:
        for preload in (False, True):
            ds = loaders.H5Dataset('tmp.h5')
            try:
                if preload:
                    ds.preload()
                ds.set_data_range([1, 5])
                assert len(ds) == 4
                item = ds[0:2]
                assert np.all(item[0] == 1 * np.ones([5, 5]))
                assert np.all(item[1] == 2 * np.ones([5, 5]))
            finally:
                ds.close()
    finally:
        rm_test_h5()