def test_stream_dataloader_error():
    """Fetching from a stream built with ``error=True`` must fail an assertion.

    ``error=True`` presumably makes MyStream emit samples in a malformed
    format; the loader is expected to reject them with an AssertionError whose
    message mentions "tuple".
    """
    stream = MyStream(100, error=True)
    loader = DataLoader(stream, StreamSampler(batch_size=4))
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        next(iter(loader))
def test_stream_dataloader_timeout(num_workers):
    """A stalling stream must make the loader raise a timeout error.

    ``block=True`` presumably makes MyStream stall for longer than the
    2-second ``timeout``; fetching the first batch is then expected to raise a
    RuntimeError whose message mentions "timeout".
    """
    stream = MyStream(100, False, block=True)
    loader = DataLoader(
        stream,
        StreamSampler(batch_size=4),
        num_workers=num_workers,
        timeout=2,
        preload=True,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        batches = iter(loader)
        next(batches)
def test_timeout_event(num_workers):
    """When the stream stalls, ``timeout_event`` should supply the batch data.

    ``cb`` returns two samples (zeros images, ones labels) per call, while the
    assertions expect a batch of 4. The loader is therefore expected to
    assemble the batch from the callback's output.
    """

    def cb():
        # The leading True presumably marks the payload as an already-batched
        # chunk of 2 samples -- NOTE(review): confirm against MyStream's format.
        return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))

    dataset = MyStream(100, block=True)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset, sampler, num_workers=num_workers, timeout=2, timeout_event=cb
    )
    # Fetch exactly one batch. Unlike a for-loop with ``break``, ``next``
    # fails loudly if the loader produces no batch at all, so the test cannot
    # pass vacuously.
    data = next(iter(dataloader))
    np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
    np.testing.assert_equal(data[1], np.ones(shape=(4,)))
def test_stream_dataloader_transform_timeout(num_workers):
    """A transform slower than the loader ``timeout`` must raise a timeout error.

    This test is given its own name. Reusing a name already bound by another
    module-level test would silently rebind it, and pytest would then never
    run the earlier test.

    ``TimeoutTransform.apply`` sleeps 10 s, which exceeds ``timeout=5``.
    Fetching the first batch is expected to raise a RuntimeError mentioning
    "timeout".
    """
    dataset = MyStream(100, False)
    sampler = StreamSampler(batch_size=4)

    class TimeoutTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            # Deliberately slower than the loader timeout below.
            time.sleep(10)
            return input

    dataloader = DataLoader(
        dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        next(data_iter)
def test_stream_dataloader(batch, num_workers):
    """Check batch shapes and label uniqueness for a preloaded stream loader.

    The first 10 batches are inspected. Images must come out as
    (4, 3, 2, 2) after ``ToMode("CHW")``, and labels as shape (4,). No label
    value may repeat across the inspected batches.
    """
    dataset = MyStream(100, batch=batch)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset,
        sampler,
        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
        num_workers=num_workers,
        preload=True,
    )

    check_set = set()
    for step, data in enumerate(dataloader):
        if step == 10:
            break
        assert data[0]._tuple_shape == (4, 3, 2, 2)
        assert data[1]._tuple_shape == (4,)
        for label in data[1]:
            # Compare by scalar value: iterating a tensor yields a fresh
            # element object each time. If such objects hash by identity,
            # they never collide in a set, and the uniqueness check would be
            # vacuous. ``.item()`` works for both tensors and numpy scalars.
            value = label.item()
            assert value not in check_set
            check_set.add(value)