Beispiel #1
0
def test_update_dataloader_with_multiprocessing_context():
    """``Trainer._update_dataloader`` keeps the loader's multiprocessing context.

    A ``DataLoader`` created with ``multiprocessing_context="spawn"`` is
    re-created with a ``SequentialSampler``; the rebuilt loader must expose
    the same multiprocessing context as the original one.
    """
    dataset = RandomDataset(32, 64)
    train_loader = DataLoader(dataset,
                              batch_size=32,
                              num_workers=2,
                              multiprocessing_context="spawn",
                              shuffle=True)
    trainer = Trainer()
    new_data_loader = trainer._update_dataloader(
        train_loader, SequentialSampler(train_loader.dataset))
    assert (new_data_loader.multiprocessing_context ==
            train_loader.multiprocessing_context)
Beispiel #2
0
def test_dataloaders_with_missing_keyword_arguments():
    """``_update_dataloader`` detects DataLoader subclasses it cannot rebuild.

    Each scenario defines a ``DataLoader`` subclass with a different
    ``__init__`` signature and re-creates it with a new sampler, in both
    ``mode="fit"`` and ``mode="predict"``. Re-creation must either succeed or
    raise a ``MisconfigurationException`` listing exactly the constructor
    arguments (or instance attributes) that are missing. The expected lists
    include ``batch_size`` and ``drop_last`` only for ``mode="predict"``.
    """
    trainer = Trainer()
    ds = RandomDataset(10, 20)
    sampler = SequentialSampler(ds)

    def assert_update_fails(loader, mode, message):
        # `match` is searched as a regex; escape the brackets and quotes.
        with pytest.raises(MisconfigurationException, match=escape(message)):
            trainer._update_dataloader(loader, sampler, mode=mode)

    def assert_update_succeeds(loader):
        for mode in ("fit", "predict"):
            trainer._update_dataloader(loader, sampler, mode=mode)

    # Only `dataset` is accepted, so every sampler-related argument is missing.
    class DatasetOnlyLoader(DataLoader):
        def __init__(self, dataset):
            super().__init__(dataset)

    loader = DatasetOnlyLoader(ds)
    assert_update_fails(
        loader, "fit",
        "missing arguments are ['batch_sampler', 'sampler', 'shuffle']")
    assert_update_fails(
        loader, "predict",
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']"
    )

    # `**kwargs` can absorb any argument, so re-creation succeeds.
    class VarKwargsLoader(DataLoader):
        def __init__(self, dataset, *args, **kwargs):
            super().__init__(dataset)

    assert_update_succeeds(VarKwargsLoader(ds))

    # A fully generic pass-through signature also succeeds.
    class PassThroughLoader(DataLoader):
        def __init__(self, *foo, **bar):
            super().__init__(*foo, **bar)

    assert_update_succeeds(PassThroughLoader(ds))

    # Without `**kwargs`, only the explicitly named `shuffle` can be passed
    # back in; the remaining sampler arguments are reported as missing.
    class ExtraPositionalLoader(DataLoader):
        def __init__(self, num_feat, dataset, *args, shuffle=False):
            self.num_feat = num_feat
            super().__init__(dataset)

    loader = ExtraPositionalLoader(1, ds)
    assert_update_fails(loader, "fit",
                        "missing arguments are ['batch_sampler', 'sampler']")
    assert_update_fails(
        loader, "predict",
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']"
    )

    # `num_feat` is stored as `self.feat_num`, so it cannot be read back from
    # the instance under its parameter name.
    class RenamedAttributeLoader(DataLoader):
        def __init__(self, num_feat, dataset, **kwargs):
            self.feat_num = num_feat
            super().__init__(dataset)

    loader = RenamedAttributeLoader(1, ds)
    for mode in ("fit", "predict"):
        assert_update_fails(loader, mode,
                            "missing attributes are ['num_feat']")
Beispiel #3
0
def test_update_dataloader_raises():
    """``_update_dataloader`` raises ``ValueError`` for non-DataLoader input."""
    trainer = Trainer()
    # Escape the message so its dots match literally, not as regex wildcards.
    match = escape("needs to subclass `torch.utils.data.DataLoader")
    with pytest.raises(ValueError, match=match):
        trainer._update_dataloader(object(), object(), mode="fit")