def _multi_instances_parallel_dataloader_worker():
    """Drive two preloading DataLoader instances (train + val) at once.

    For both ``divide`` modes, iterates a batch-size-4 train loader and,
    every 5th train step, fully drains a batch-size-10 val loader, checking
    batch shapes throughout.
    """
    dataset = init_dataset()
    for flag in (True, False):
        train_loader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=flag,
            preload=True,
        )
        val_loader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
            num_workers=2,
            divide=flag,
            preload=True,
        )
        for step, (batch_data, batch_label) in enumerate(train_loader):
            assert batch_data._tuple_shape == (4, 1, 32, 32)
            assert batch_label._tuple_shape == (4,)
            # Interleave validation passes with training iteration to
            # exercise two live loader instances in the same process.
            if step % 5 == 0:
                for v_data, v_label in val_loader:
                    assert v_data._tuple_shape == (10, 1, 32, 32)
                    assert v_label._tuple_shape == (10,)
def test_dataloader_parallel():
    """Parallel (num_workers=2) preloading loader yields correct batch shapes.

    Checks both ``divide`` modes. The original body was the same code
    copy-pasted twice with only ``divide`` flipped; a loop removes the
    duplication without changing what is exercised.
    """
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"
    dataset = init_dataset()
    for divide_flag in (False, True):
        dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide_flag,
            preload=True,
        )
        for data, label in dataloader:
            assert data._tuple_shape == (4, 1, 32, 32)
            assert label._tuple_shape == (4,)
def test_stream_dataloader_error():
    """A misbehaving stream dataset must surface an AssertionError.

    ``MyStream(error=True)`` yields malformed items; iterating the loader
    is expected to fail with a message mentioning ``tuple``.
    """
    stream = MyStream(100, error=True)
    loader = DataLoader(stream, StreamSampler(batch_size=4))
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        it = iter(loader)
        next(it)
def test_dataloader_serial():
    """Serial (no-worker) loader produces correctly shaped batches."""
    dataset = init_dataset()
    loader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
    )
    for batch in loader:
        images, labels = batch
        assert images.shape == (4, 1, 32, 32)
        assert labels.shape == (4,)
def test_OmniglotDataset_loader():
    """Smoke-test one batch from the Omniglot dataset loader."""
    dataset = build_dataset()
    seq_sampler = msp.SequentialSampler(dataset=dataset, batch_size=32, drop_last=True)
    loader = DataLoader(dataset=dataset, sampler=seq_sampler, num_workers=4)
    for support_im, support_lb, query_im, query_lb in loader:
        # Support images ~ [32, 5, 1, 105, 105, 1]; support labels ~ (32, 5, 1, 5)
        support_im.shape
        support_lb.shape
        break  # one batch is enough for a smoke test
def test_dataloader_init():
    """Constructor validation and defaults of DataLoader.

    Invalid worker/timeout/divide combinations must raise ValueError;
    defaults fall back to SequentialSampler / PseudoTransform / Collator;
    ``len`` reflects ``drop_last`` (ceil vs floor of dataset/batch_size).
    """
    dataset = init_dataset()

    # Invalid configurations -> ValueError.
    bad_kwargs = (
        dict(num_workers=2, divide=True),
        dict(num_workers=-1),
        dict(timeout=-1),
        dict(num_workers=0, divide=True),
    )
    for kwargs in bad_kwargs:
        with pytest.raises(ValueError):
            DataLoader(dataset, **kwargs)

    # Defaults.
    loader = DataLoader(dataset)
    assert isinstance(loader.sampler, SequentialSampler)
    assert isinstance(loader.transform, PseudoTransform)
    assert isinstance(loader.collator, Collator)

    # Length semantics: drop_last=False rounds up, drop_last=True rounds down.
    loader = DataLoader(
        dataset, sampler=RandomSampler(dataset, batch_size=6, drop_last=False)
    )
    assert len(loader) == 17
    loader = DataLoader(
        dataset, sampler=RandomSampler(dataset, batch_size=6, drop_last=True)
    )
    assert len(loader) == 16
def test_stream_dataloader_timeout(num_workers):
    """A blocking stream dataset must trip the loader's timeout.

    ``MyStream(block=True)`` never produces data, so fetching the first
    batch with ``timeout=2`` is expected to raise a RuntimeError whose
    message mentions the timeout.
    """
    blocking_stream = MyStream(100, False, block=True)
    loader = DataLoader(
        blocking_stream,
        StreamSampler(batch_size=4),
        num_workers=num_workers,
        timeout=2,
        preload=True,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        it = iter(loader)
        next(it)
def build_dataloader(root=Path('/home/zqh/data/omniglot-py'), nway=5, kshot=1, kquery=1, batch_size=32):
    """Build a training DataLoader over the Omniglot dataset.

    NOTE(review): the default ``root`` is a machine-specific path — callers
    on other machines must pass their own.
    """
    dataset = OmniglotDataset(root, nway, kshot, kquery, mode='train')
    sampler = msp.SequentialSampler(dataset, drop_last=True, batch_size=batch_size)
    return DataLoader(dataset, sampler=sampler, num_workers=4)
def test_timeout_event(num_workers):
    """When the loader times out, the ``timeout_event`` callback's fallback
    batch is delivered instead of raising."""

    def fallback():
        # (ok, (data, label)) replacement payload used on timeout.
        return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))

    blocking_stream = MyStream(100, block=True)
    loader = DataLoader(
        blocking_stream,
        StreamSampler(batch_size=4),
        num_workers=num_workers,
        timeout=2,
        timeout_event=fallback,
    )
    for batch in loader:
        np.testing.assert_equal(batch[0], np.zeros(shape=(4, 2, 2, 3)))
        np.testing.assert_equal(batch[1], np.ones(shape=(4,)))
        break  # a single substituted batch proves the callback path
def test_dataloader_parallel_worker_exception():
    """A transform that crashes inside a worker process must surface as a
    RuntimeError matching ``worker.*died`` when the batch is fetched."""
    dataset = init_dataset()

    class FakeErrorTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            # NOTE: `x` is intentionally undefined — evaluating it raises
            # NameError inside the worker process, killing the worker.
            # Do not "fix" this line.
            y = x + 1
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=FakeErrorTransform(),
        num_workers=2,
    )
    # The worker death is reported on the first fetch, not at iter() time.
    with pytest.raises(RuntimeError, match=r"worker.*died"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)
def test_stream_dataloader_timeout_with_transform(num_workers):
    """A slow transform (sleeps past ``timeout``) must raise RuntimeError.

    Renamed from ``test_stream_dataloader_timeout``: this module already
    defines a test of that name, and the duplicate definition shadowed it
    so pytest only ever collected one of the two. The unique name lets
    both tests run.
    """
    dataset = MyStream(100, False)
    sampler = StreamSampler(batch_size=4)

    class TimeoutTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            # Sleep well past the loader's 5s timeout to force the error.
            time.sleep(10)
            return input

    dataloader = DataLoader(
        dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        next(data_iter)
def test_dataloader_parallel_timeout():
    """Parallel map-style loader raises RuntimeError when a transform
    exceeds ``timeout``."""
    dataset = init_dataset()

    class TimeoutTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            # Sleep far longer than the 2-second loader timeout.
            time.sleep(10)
            return input

    slow_loader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=TimeoutTransform(),
        num_workers=2,
        timeout=2,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        next(iter(slow_loader))
def test_stream_dataloader(batch, num_workers):
    """Stream dataset through Normalize+ToMode transforms: shapes are right
    and no label repeats across the first 10 batches."""
    stream = MyStream(100, batch=batch)
    transforms = Compose(
        [Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]
    )
    loader = DataLoader(
        stream,
        StreamSampler(batch_size=4),
        transforms,
        num_workers=num_workers,
        preload=True,
    )
    seen_labels = set()
    for step, batch_data in enumerate(loader):
        if step == 10:
            break
        assert batch_data[0]._tuple_shape == (4, 3, 2, 2)
        assert batch_data[1]._tuple_shape == (4,)
        # Labels must be unique across everything consumed so far.
        for label in batch_data[1]:
            assert label not in seen_labels
            seen_labels.add(label)