def test_update_dataloader_with_multiprocessing_context():
    """``_update_dataloader`` must keep the source loader's ``multiprocessing_context``."""
    dataset = RandomDataset(32, 64)
    original_loader = DataLoader(
        dataset,
        batch_size=32,
        num_workers=2,
        multiprocessing_context="spawn",
        shuffle=True,
    )
    trainer = Trainer()

    # Rebuild the loader around a new sampler and compare the context it ends up with.
    rebuilt_loader = trainer._update_dataloader(
        original_loader, SequentialSampler(original_loader.dataset)
    )
    assert rebuilt_loader.multiprocessing_context == original_loader.multiprocessing_context
def test_dataloaders_with_missing_keyword_arguments():
    """Exercise ``Trainer._update_dataloader`` against custom ``DataLoader`` subclasses.

    Each scenario defines a ``DataLoader`` subclass with a different ``__init__``
    signature and checks one of two outcomes for the ``"fit"`` and ``"predict"`` modes:

    - the loader is re-instantiated without error, or
    - a ``MisconfigurationException`` is raised whose message lists exactly the
      missing arguments or attributes.
    """
    trainer = Trainer()
    random_ds = RandomDataset(10, 20)
    both_modes = ("fit", "predict")

    def assert_rejected(loader, sampler, mode, message):
        # Escape `message`: its bracketed lists would otherwise be parsed as regex classes.
        with pytest.raises(MisconfigurationException, match=escape(message)):
            trainer._update_dataloader(loader, sampler, mode=mode)

    def assert_accepted(loader, sampler):
        for mode in both_modes:
            trainer._update_dataloader(loader, sampler, mode=mode)

    # Scenario 1: __init__ accepts only `dataset`.
    class TestDataLoader(DataLoader):
        def __init__(self, dataset):
            super().__init__(dataset)

    loader = TestDataLoader(random_ds)
    sampler = SequentialSampler(random_ds)
    assert_rejected(
        loader,
        sampler,
        "fit",
        "missing arguments are ['batch_sampler', 'sampler', 'shuffle']",
    )
    assert_rejected(
        loader,
        sampler,
        "predict",
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']",
    )

    # Scenario 2: extra arguments are absorbed by *args / **kwargs -> accepted.
    class TestDataLoader(DataLoader):
        def __init__(self, dataset, *args, **kwargs):
            super().__init__(dataset)

    assert_accepted(TestDataLoader(random_ds), SequentialSampler(random_ds))

    # Scenario 3: every argument is forwarded verbatim to DataLoader -> accepted.
    class TestDataLoader(DataLoader):
        def __init__(self, *foo, **bar):
            super().__init__(*foo, **bar)

    assert_accepted(TestDataLoader(random_ds), SequentialSampler(random_ds))

    # Scenario 4: *args plus an explicit keyword-only `shuffle`, but no **kwargs.
    # `shuffle` is no longer reported among the missing arguments.
    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, *args, shuffle=False):
            self.num_feat = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, random_ds)
    sampler = SequentialSampler(random_ds)
    assert_rejected(
        loader,
        sampler,
        "fit",
        "missing arguments are ['batch_sampler', 'sampler']",
    )
    assert_rejected(
        loader,
        sampler,
        "predict",
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']",
    )

    # Scenario 5: the `num_feat` parameter is stored as `feat_num`, not `num_feat`.
    # The test expects `num_feat` to be reported as a missing attribute in both modes.
    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, **kwargs):
            self.feat_num = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, random_ds)
    sampler = SequentialSampler(random_ds)
    for mode in both_modes:
        assert_rejected(loader, sampler, mode, "missing attributes are ['num_feat']")
def test_update_dataloader_raises():
    """Objects that are not ``DataLoader`` instances must be rejected with ``ValueError``."""
    trainer = Trainer()
    not_a_loader = object()
    not_a_sampler = object()
    expected = "needs to subclass `torch.utils.data.DataLoader"
    with pytest.raises(ValueError, match=expected):
        trainer._update_dataloader(not_a_loader, not_a_sampler, mode="fit")