예제 #1
0
def test_stream_dataloader_error():
    """Malformed stream items must be rejected with an AssertionError.

    ``MyStream(..., error=True)`` is presumably configured to yield items in
    an unexpected format (TODO confirm against MyStream); pulling the first
    batch is expected to fail with a message mentioning "tuple".
    """
    loader = DataLoader(MyStream(100, error=True), StreamSampler(batch_size=4))
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        it = iter(loader)
        next(it)
예제 #2
0
def test_stream_dataloader_timeout(num_workers):
    """A stream that stalls past ``timeout`` must raise a RuntimeError.

    ``MyStream(..., block=True)`` presumably blocks before yielding its items;
    with a 2-second timeout and preloading enabled, fetching the first batch
    is expected to fail with a message mentioning "timeout".
    """
    loader = DataLoader(
        MyStream(100, False, block=True),
        StreamSampler(batch_size=4),
        num_workers=num_workers,
        preload=True,
        timeout=2,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        it = iter(loader)
        next(it)
예제 #3
0
def test_timeout_event(num_workers):
    """Check that ``timeout_event`` supplies data when the stream stalls.

    ``MyStream(..., block=True)`` is expected to block longer than the
    2-second ``timeout``; the loader should then call ``cb`` instead of
    raising, and the first collated batch must be built from the callback's
    arrays (all-zero images, all-one labels).
    """

    def cb():
        # Returns a (flag, (images, labels)) pair shaped like a stream item.
        # The leading True presumably marks it as an already-batched sample
        # of size 2 -- TODO confirm against the stream item contract.
        return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))

    dataset = MyStream(100, block=True)
    sampler = StreamSampler(batch_size=4)

    dataloader = DataLoader(
        dataset,
        sampler,
        num_workers=num_workers,
        timeout=2,
        timeout_event=cb,
    )
    # Take exactly one batch. A ``for ...: break`` loop would let the test
    # pass without checking anything if the loader yielded no batch at all.
    data = next(iter(dataloader), None)
    assert data is not None, "dataloader produced no batches"
    # Presumably two size-2 callback results fill one batch of 4.
    np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
    np.testing.assert_equal(data[1], np.ones(shape=(4,)))
예제 #4
0
def test_stream_dataloader_timeout(num_workers):
    """A transform slower than ``timeout`` must surface as a RuntimeError.

    Every sample's transform sleeps for 10 seconds while the loader waits at
    most 5, so fetching the first batch is expected to fail with a message
    mentioning "timeout".
    """

    class SleepyTransform(Transform):
        # Does not call Transform.__init__; this transform keeps no state.
        def __init__(self):
            pass

        # ``input`` keeps the parameter name of the overridden method.
        def apply(self, input):
            time.sleep(10)  # well beyond the loader's 5-second timeout
            return input

    slow_loader = DataLoader(
        MyStream(100, False),
        StreamSampler(batch_size=4),
        SleepyTransform(),
        timeout=5,
        num_workers=num_workers,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        batch_iter = iter(slow_loader)
        next(batch_iter)
예제 #5
0
def test_stream_dataloader(batch, num_workers):
    """Check batch shapes and label uniqueness for a preloaded stream loader.

    Reads the first 10 batches. For each batch it asserts:

    - images have shape ``(4, 3, 2, 2)``, presumably the channel-first layout
      produced by ``ToMode("CHW")``;
    - labels have shape ``(4,)``;
    - no label value has already appeared in an earlier batch.
    """
    dataset = MyStream(100, batch=batch)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset,
        sampler,
        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
        num_workers=num_workers,
        preload=True,
    )

    num_steps = 10
    check_set = set()

    for step, data in enumerate(dataloader):
        if step == num_steps:
            break
        assert data[0]._tuple_shape == (4, 3, 2, 2)
        assert data[1]._tuple_shape == (4,)
        # Compare label *values*, not the per-element tensor objects that
        # iteration creates. Those may hash by identity, which would make the
        # set never detect a duplicate (NOTE(review): confirm for this Tensor type).
        for label in data[1].numpy().tolist():
            assert label not in check_set
            check_set.add(label)

    # Fails if the loader produced fewer batches than expected. Otherwise the
    # checks above could pass without covering all num_steps batches.
    assert len(check_set) == num_steps * 4