Example #1
0
def test_distributive():
    """Check that each rank's loader covers only its contiguous shard of the data.

    A dataset of ``n_dpts`` items is split over ``size`` ranks. For loaders
    with ``batch_size=1`` and ``batch_size=2`` the test checks, for every
    rank:

    - the reported ``data_range``;
    - that the batches reproduce the shard's items in order;
    - that the whole shard is iterated, so an empty iteration cannot pass
      vacuously.
    """
    size = 4
    n_dpts = 47
    ds = np.arange(n_dpts)
    # Items per rank. The original test hard-coded 12 for these values;
    # ceil(n_dpts / size) reproduces that number.
    # NOTE(review): assumes the loader shards with ceil-division -- confirm
    # before changing size/n_dpts.
    shard = -(-n_dpts // size)

    ddl_1 = loaders.DistributedDataLoader(ds, 0, size, batch_size=1)
    ddl_2 = loaders.DistributedDataLoader(ds, 0, size, batch_size=2)

    for rank in range(size):
        # Re-target the same loader objects at another rank instead of
        # constructing new ones.
        ddl_1.rank = rank
        ddl_2.rank = rank

        st = shard * rank
        # The last rank's shard is truncated at the end of the dataset.
        sp = min(shard * (rank + 1), n_dpts)
        assert ddl_1.data_range == (st, sp)
        assert ddl_2.data_range == (st, sp)

        n_seen = 0
        for i, b in enumerate(ddl_1):
            assert np.all(b.numpy() == ds[st + i])
            n_seen += 1
        assert n_seen == sp - st

        n_seen = 0
        for i, b in enumerate(ddl_2):
            assert np.all(b.numpy() == ds[st + i * 2:st + (i + 1) * 2])
            # Count items rather than batches: the final batch may be short.
            n_seen += len(b)
        assert n_seen == sp - st
Example #2
0
def test_pin_memory():
    """Check that a loader built with ``pin_memory=True`` yields the source items in order.

    Runs as the only rank (rank 0 of 1) with one item per batch.
    """
    data = np.arange(47)
    loader = loaders.DistributedDataLoader(
        data, 0, 1, batch_size=1, pin_memory=True
    )
    for idx, batch in enumerate(loader):
        assert np.all(batch.numpy() == data[idx])
Example #3
0
def test_epoch_style_use():
    """Check that re-iterating one loader over several epochs yields the same batches.

    The dataset is always closed and ``rm_test_h5()`` always runs, even when
    an assertion fails. Otherwise a failure here would leave the HDF5 test
    file (presumably the ``tmp.h5`` opened below) in place for later tests.
    """
    make_test_h5()
    try:
        ds = loaders.H5Dataset('tmp.h5')
        try:
            ddl = loaders.DistributedDataLoader(ds, 0, 2, batch_size=2)
            for _ in range(5):
                # enumerate() restarts each epoch, so the same loader must
                # yield the same batch sequence from the beginning.
                for i, b in enumerate(ddl):
                    assert np.all(b[0].numpy() == i)
                    assert np.all(b[1].numpy() == i + 1)
                    assert b.shape == (2, 5, 5)
        finally:
            ds.close()
    finally:
        rm_test_h5()
Example #4
0
def test_integration():
    """Check that an HDF5-backed loader yields the expected batches, with and without ``preload()``.

    Each pass opens a fresh ``H5Dataset`` on the same file and always closes
    it. The test file is removed even if an assertion fails, so a failure
    cannot leak an open dataset or leave the file behind.
    """
    make_test_h5()
    try:
        # First pass reads lazily; the second calls H5Dataset.preload() first.
        for preload in (False, True):
            ds = loaders.H5Dataset('tmp.h5')
            try:
                if preload:
                    ds.preload()
                ddl = loaders.DistributedDataLoader(ds, 0, 2, batch_size=2)
                for i, b in enumerate(ddl):
                    assert np.all(b[0].numpy() == i)
                    assert np.all(b[1].numpy() == i + 1)
                    assert b.shape == (2, 5, 5)
            finally:
                ds.close()
    finally:
        rm_test_h5()
Example #5
0
def test_batch_size():
    """Check batch slicing and ``len()`` for ``batch_size=3``, with and without ``drop_last``.

    Runs as the only rank (rank 0 of 1). Because 47 is not a multiple of 3,
    ``len()`` counts a short trailing batch unless ``drop_last=True``.
    """
    rank, size, n_points, bs = 0, 1, 47, 3

    data = np.random.randn(n_points, 2, 3)
    loader = loaders.DistributedDataLoader(data, rank, size, batch_size=bs)
    assert len(loader) == n_points // bs + 1  # trailing partial batch

    for idx, batch in enumerate(loader):
        lo = idx * bs
        assert np.all(batch.numpy() == data[lo:lo + bs])

    # Fresh data; with drop_last, len() excludes the trailing partial batch.
    data = np.random.randn(n_points, 2, 3)
    loader = loaders.DistributedDataLoader(
        data, rank, size, batch_size=bs, drop_last=True
    )
    assert len(loader) == n_points // bs
Example #6
0
def test_basic():
    """Check that a single-rank loader with ``batch_size=1`` spans and yields the whole dataset."""
    n_points = 47
    data = np.random.randn(n_points, 2, 3)
    loader = loaders.DistributedDataLoader(data, 0, 1, batch_size=1)

    # With only one rank, its range is the full dataset: one iteration per item.
    assert loader.data_range == (0, n_points)
    assert loader.n_iter == n_points
    assert len(loader) == n_points

    for idx, batch in enumerate(loader):
        assert np.all(batch.numpy() == data[idx])