Example #1
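# Runs a train and a val DataLoader over the same dataset with two worker
# processes, toggling `divide`: in MegEngine's multiprocess DataLoader,
# divide=True splits each batch across the workers, while divide=False hands
# each worker a whole batch. With preload=True, batches expose their shape
# via `_tuple_shape`.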
def _multi_instances_parallel_dataloader_worker():
    dataset = init_dataset()

    for divide_flag in [True, False]:
        train_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide_flag,
            preload=True,
        )
        val_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
            num_workers=2,
            divide=divide_flag,
            preload=True,
        )
        for idx, (data, label) in enumerate(train_dataloader):
            assert data._tuple_shape == (4, 1, 32, 32)
            assert label._tuple_shape == (4,)
            if idx % 5 == 0:
                for val_data, val_label in val_dataloader:
                    assert val_data._tuple_shape == (10, 1, 32, 32)
                    assert val_label._tuple_shape == (10,)
Example #2
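# MGE_PLASMA_MEMORY caps the plasma shared-memory store that MegEngine's
# worker processes use to pass preloaded batches back to the main process;
# 100000000 bytes is roughly 100 MB.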
def test_dataloader_parallel():
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    dataset = init_dataset()
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        num_workers=2,
        divide=False,
        preload=True,
    )
    for (data, label) in dataloader:
        assert data._tuple_shape == (4, 1, 32, 32)
        assert label._tuple_shape == (4,)

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        num_workers=2,
        divide=True,
        preload=True,
    )
    for (data, label) in dataloader:
        assert data._tuple_shape == (4, 1, 32, 32)
        assert label._tuple_shape == (4,)
Example #3
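# MyStream(error=True) presumably yields items that are not tuples, so
# iteration should trip the DataLoader's "must be a tuple" assertion on the
# first batch.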
def test_stream_dataloader_error():
    dataset = MyStream(100, error=True)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(dataset, sampler)
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        data_iter = iter(dataloader)
        next(data_iter)
Example #4
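# With num_workers left at its default, batches are produced serially in the
# main process and checked via plain `.shape` rather than `_tuple_shape`.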
def test_dataloader_serial():
    dataset = init_dataset()
    dataloader = DataLoader(
        dataset, sampler=RandomSampler(dataset, batch_size=4, drop_last=False)
    )
    for (data, label) in dataloader:
        assert data.shape == (4, 1, 32, 32)
        assert label.shape == (4,)
Example #5
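# Each batch is a 4-tuple of support/query images and labels for N-way K-shot
# Omniglot episodes; Example #8 shows the matching build_dataloader setup.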
def test_OmniglotDataset_loader():
    ds = build_dataset()
    sampler = msp.SequentialSampler(dataset=ds, batch_size=32, drop_last=True)
    dataload = DataLoader(dataset=ds, sampler=sampler, num_workers=4)
    for im_s, lb_s, im_q, lb_q in dataload:
        im_s.shape  # (32, 5, 1, 105, 105, 1)
        lb_s.shape  # (32, 5, 1, 5)
        break
Example #6
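# The first four constructions should be rejected: divide=True with a batch
# size not divisible by num_workers (the default sampler presumably uses
# batch_size=1), a negative num_workers, a negative timeout, and divide=True
# without any worker processes.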
def test_dataloader_init():
    dataset = init_dataset()
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=2, divide=True)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=-1)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, timeout=-1)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=0, divide=True)

    dataloader = DataLoader(dataset)
    assert isinstance(dataloader.sampler, SequentialSampler)
    assert isinstance(dataloader.transform, PseudoTransform)
    assert isinstance(dataloader.collator, Collator)

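    # init_dataset() presumably holds 100 samples: ceil(100 / 6) = 17 batches
    # without drop_last, floor(100 / 6) = 16 with it.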
    dataloader = DataLoader(dataset,
                            sampler=RandomSampler(dataset,
                                                  batch_size=6,
                                                  drop_last=False))
    assert len(dataloader) == 17
    dataloader = DataLoader(dataset,
                            sampler=RandomSampler(dataset,
                                                  batch_size=6,
                                                  drop_last=True))
    assert len(dataloader) == 16
Example #7
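# `num_workers` is presumably supplied by pytest parametrization. The stream
# blocks indefinitely, so fetching the first batch should exceed the 2-second
# timeout and raise.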
def test_stream_dataloader_timeout(num_workers):
    dataset = MyStream(100, False, block=True)
    sampler = StreamSampler(batch_size=4)

    dataloader = DataLoader(
        dataset, sampler, num_workers=num_workers, timeout=2, preload=True
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        next(data_iter)
Example #8
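# Builds an N-way K-shot Omniglot loader. `msp` is presumably an alias for
# MegEngine's sampler module, and the default `root` is a machine-specific
# path kept as in the source.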
def build_dataloader(root=Path('/home/zqh/data/omniglot-py'),
                     nway=5,
                     kshot=1,
                     kquery=1,
                     batch_size=32):
    train_ds = OmniglotDataset(root, nway, kshot, kquery, mode='train')
    train_smp = msp.SequentialSampler(train_ds,
                                      drop_last=True,
                                      batch_size=batch_size)
    train_loader = DataLoader(train_ds, sampler=train_smp, num_workers=4)

    return train_loader
Example #9
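# On timeout, `timeout_event` is called instead of raising; the (True, batch)
# payload marks the data as an already-collated batch that is returned as-is,
# so the assertions below match the callback's own shapes.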
def test_timeout_event(num_workers):
    def cb():
        return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2, ))))

    dataset = MyStream(100, block=True)
    sampler = StreamSampler(batch_size=4)

    dataloader = DataLoader(dataset,
                            sampler,
                            num_workers=num_workers,
                            timeout=2,
                            timeout_event=cb)
    for _, data in enumerate(dataloader):
        np.testing.assert_equal(data[0], np.zeros(shape=(2, 2, 2, 3)))
        np.testing.assert_equal(data[1], np.ones(shape=(2,)))
        break
Example #10
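# An exception raised inside a worker process should surface in the main
# process as a "worker ... died" RuntimeError.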
def test_dataloader_parallel_worker_exception():
    dataset = init_dataset()

    class FakeErrorTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            # `x` is intentionally undefined: the resulting NameError kills
            # the worker process, triggering the "worker ... died" error.
            y = x + 1
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=FakeErrorTransform(),
        num_workers=2,
    )
    with pytest.raises(RuntimeError, match=r"worker.*died"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)
Example #11
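# Another timeout variant (it shares its name with Example #7): here a
# transform that sleeps for 10 seconds overruns the 5-second timeout.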
def test_stream_dataloader_timeout(num_workers):
    dataset = MyStream(100, False)
    sampler = StreamSampler(batch_size=4)

    class TimeoutTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            time.sleep(10)
            return input

    dataloader = DataLoader(dataset,
                            sampler,
                            TimeoutTransform(),
                            num_workers=num_workers,
                            timeout=5)
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        next(data_iter)
Example #12
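# The same sleeping-transform trick as above, applied to a map-style dataset
# with two parallel workers and a 2-second timeout.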
def test_dataloader_parallel_timeout():
    dataset = init_dataset()

    class TimeoutTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            time.sleep(10)
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=TimeoutTransform(),
        num_workers=2,
        timeout=2,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)
Example #13
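# Streams ten batches through a Normalize + ToMode("CHW") pipeline; collecting
# labels into a set verifies the stream does not repeat samples.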
def test_stream_dataloader(batch, num_workers):
    dataset = MyStream(100, batch=batch)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset,
        sampler,
        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
        num_workers=num_workers,
        preload=True,
    )

    check_set = set()

    for step, data in enumerate(dataloader):
        if step == 10:
            break
        assert data[0]._tuple_shape == (4, 3, 2, 2)
        assert data[1]._tuple_shape == (4,)
        for i in data[1]:
            assert i not in check_set
            check_set.add(i)