Example #1
0
def test_has_len():
    """``has_len`` detects length support, warning on empty map-style loaders."""
    # Iterable-style datasets expose no ``__len__`` at all.
    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))

    # A sized, non-empty dataset is reported as having a length.
    assert has_len(DataLoader(RandomDataset(1, 1)))

    # A zero-length loader still "has len", but triggers a warning.
    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))
def test_boring_lite_model_ddp(precision, strategy, devices, accelerator,
                               tmpdir):
    """Weights trained through LightningLite must match a plain training run.

    Trains once via ``LiteRunner`` and once via the raw ``run`` function,
    re-seeding before each so both start from identical weights and data order.
    """
    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 4))
    model = BoringModel()
    num_epochs = 1
    # Snapshot of the untrained weights, used below to prove training moved them.
    state_dict = deepcopy(model.state_dict())

    lite = LiteRunner(precision=precision,
                      strategy=strategy,
                      devices=devices,
                      accelerator=accelerator)
    lite.run(model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir)

    lite_model_state_dict = model.state_dict()

    # Training must have changed every parameter from its initial value.
    for w_pure, w_lite in zip(state_dict.values(),
                              lite_model_state_dict.values()):
        assert not torch.equal(w_pure.cpu(), w_lite.cpu())

    # Re-seed and repeat the same training without Lite for comparison.
    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 4))
    model = BoringModel()
    run(lite.global_rank, model, train_dataloader, num_epochs, precision,
        accelerator, tmpdir)
    pure_model_state_dict = model.state_dict()

    # Both training paths must land on identical weights.
    for w_pure, w_lite in zip(pure_model_state_dict.values(),
                              lite_model_state_dict.values()):
        assert torch.equal(w_pure.cpu(), w_lite.cpu())
 def predict_dataloader(self):
     """Return a CombinedLoader over two named prediction dataloaders."""
     loaders = {
         "a": DataLoader(RandomDataset(32, 8), batch_size=2),
         "b": DataLoader(RandomDataset(32, 8), batch_size=4),
     }
     return CombinedLoader(loaders)
 def train_dataloader(self):
     """Training dataloader; length doubles when a StopIteration is to be triggered."""
     length = EXPECT_NUM_BATCHES_PROCESSED
     if self.trigger_stop_iteration:
         length *= 2
     return DataLoader(RandomDataset(BATCH_SIZE, length))
Example #5
0
 def test_dataloader(self):
     """Two test dataloaders when ``suffix`` is set; otherwise the parent's."""
     if not suffix:
         return super().test_dataloader()
     return [
         torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(2)
     ]
Example #6
0
def test_model_nohparams_train_test(tmpdir, cls):
    """Test models that do not take any argument in init."""
    model = cls()
    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)

    # Fit, then test, each with a fresh dataloader over the same random data.
    trainer.fit(model, DataLoader(RandomDataset(32, 64), batch_size=32))
    trainer.test(dataloaders=DataLoader(RandomDataset(32, 64), batch_size=32))
Example #7
0
def test_has_len_all_rank():
    """``has_len_all_ranks`` warns on an empty loader and accepts a non-empty one."""
    model = BoringModel()
    trainer = Trainer(fast_dev_run=True)

    empty_loader = DataLoader(RandomDataset(0, 0))
    with pytest.warns(
            UserWarning,
            match="Total length of `DataLoader` across ranks is zero."):
        assert has_len_all_ranks(empty_loader, trainer.strategy, model)

    nonempty_loader = DataLoader(RandomDataset(1, 1))
    assert has_len_all_ranks(nonempty_loader, trainer.strategy, model)
Example #8
0
def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir):
    """Metrics logged with ``reduce_fx="sum"`` over 3 batches must reduce to 3."""
    class TestModel(BoringModel):
        def on_train_epoch_end(self):
            # 3 train batches x 1.0, summed per epoch.
            assert all(v == 3 for v in self.trainer.callback_metrics.values())

        def on_validation_epoch_end(self):
            # 3 val batches x 1.0, summed per epoch.
            assert all(v == 3 for v in self.trainer.callback_metrics.values())

        def on_train_batch_start(self, batch, batch_idx):
            self.log("on_train_batch_start",
                     1.0,
                     on_step=False,
                     on_epoch=True,
                     reduce_fx="sum")

        def on_train_batch_end(self, outputs, batch, batch_idx):
            self.log("on_train_batch_end",
                     1.0,
                     on_step=False,
                     on_epoch=True,
                     reduce_fx="sum")

        def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
            self.log("on_validation_batch_start", 1.0, reduce_fx="sum")

        def on_validation_batch_end(self, outputs, batch, batch_idx,
                                    dataloader_idx):
            self.log("on_validation_batch_end", 1.0, reduce_fx="sum")

        def training_epoch_end(self, *_) -> None:
            # Epoch-end hooks log once, so mean reduction keeps the raw value.
            self.log("training_epoch_end", 3.0, reduce_fx="mean")
            assert self.trainer._results[
                "training_epoch_end.training_epoch_end"].value == 3.0

        def validation_epoch_end(self, *_) -> None:
            self.log("validation_epoch_end", 3.0, reduce_fx="mean")
            assert self.trainer._results[
                "validation_epoch_end.validation_epoch_end"].value == 3.0

    model = TestModel()
    # Exactly 3 batches per phase so the summed metrics equal 3.
    trainer = Trainer(
        enable_progress_bar=False,
        limit_train_batches=3,
        limit_val_batches=3,
        num_sanity_val_steps=3,
        max_epochs=1,
    )
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
Example #9
0
def test_get_len():
    """``get_len`` returns the true length for sized loaders and inf otherwise."""
    sized = DataLoader(RandomDataset(1, 1))
    assert get_len(sized) == 1

    # Iterable datasets have no length; get_len falls back to infinity.
    unsized_len = get_len(DataLoader(RandomIterableDataset(1, 1)))
    assert isinstance(unsized_len, float)
    assert unsized_len == float("inf")
def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator,
                                     tmpdir):
    """Spawn-launched Lite DDP training must match a manual ``mp.spawn`` run."""
    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 8))
    model = BoringModel()
    num_epochs = 1
    # Snapshot of the untrained weights, used to prove training changed them.
    state_dict = deepcopy(model.state_dict())

    lite = LiteRunner(precision=precision,
                      strategy=strategy,
                      devices=devices,
                      accelerator=accelerator)
    checkpoint_path = lite.run(model,
                               train_dataloader,
                               num_epochs=num_epochs,
                               tmpdir=tmpdir)
    # Spawned workers cannot mutate this process's model in place; read the
    # trained weights back from the checkpoint the run produced.
    spawn_model_state_dict = torch.load(checkpoint_path)

    # Training must have moved every parameter away from its initial value.
    for w_pure, w_lite in zip(state_dict.values(),
                              spawn_model_state_dict.values()):
        assert not torch.equal(w_pure.cpu(), w_lite.cpu())

    # Restore the initial weights and repeat training with a hand-rolled spawn.
    model.load_state_dict(state_dict)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(find_free_network_port())
    mp.spawn(run,
             args=(model, train_dataloader, num_epochs, precision, accelerator,
                   tmpdir),
             nprocs=2)
    spawn_pure_model_state_dict = torch.load(
        os.path.join(tmpdir, "model_spawn.pt"))

    # Both spawn-trained models must land on identical weights.
    for w_pure, w_lite in zip(spawn_pure_model_state_dict.values(),
                              spawn_model_state_dict.values()):
        assert torch.equal(w_pure.cpu(), w_lite.cpu())
Example #11
0
def test_trainer_predict_verify_config(tmpdir, datamodule):
    """``Trainer.predict`` accepts a list of dataloaders or a datamodule alike."""
    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

    class TestLightningDataModule(LightningDataModule):
        # Thin datamodule exposing the given dataloaders for test and predict.
        def __init__(self, dataloaders):
            super().__init__()
            self._dataloaders = dataloaders

        def test_dataloader(self):
            return self._dataloaders

        def predict_dataloader(self):
            return self._dataloaders

    data = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]
    if datamodule:
        data = TestLightningDataModule(data)

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir)
    results = trainer.predict(model, data)

    # One result list per dataloader; batch of 1 through a Linear(32, 2).
    assert len(results) == 2
    assert results[0][0].shape == torch.Size([1, 2])
def test_boring_lite_model_single_device(precision, strategy, devices,
                                         accelerator, tmpdir):
    """Single-device Lite training must reproduce a pure-PyTorch training loop."""
    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 8))
    model = BoringModel()
    num_epochs = 1
    # Snapshot of the untrained weights for both comparisons below.
    state_dict = deepcopy(model.state_dict())

    lite = LiteRunner(precision=precision,
                      strategy=strategy,
                      devices=devices,
                      accelerator=accelerator)
    lite.run(model, train_dataloader, num_epochs=num_epochs)
    lite_state_dict = model.state_dict()

    # Re-run the same training without Lite under the matching precision.
    with precision_context(precision, accelerator):
        model.load_state_dict(state_dict)
        pure_state_dict = main(lite.to_device,
                               model,
                               train_dataloader,
                               num_epochs=num_epochs)

    # Move the initial weights to the training device before comparing.
    state_dict = apply_to_collection(state_dict, torch.Tensor, lite.to_device)
    for w_pure, w_lite in zip(state_dict.values(), lite_state_dict.values()):
        assert not torch.equal(w_pure, w_lite)

    # Lite-trained and pure-trained weights must match exactly.
    for w_pure, w_lite in zip(pure_state_dict.values(),
                              lite_state_dict.values()):
        assert torch.equal(w_pure, w_lite)
Example #13
0
def test_has_iterable_dataset():
    """Detection keys on ``IterableDataset`` subclassing, not on ``__iter__``."""
    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))
    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))

    class MockDatasetWithoutIterableDataset(RandomDataset):
        # Defines __iter__ yet does not subclass IterableDataset.
        def __iter__(self):
            yield 1
            return self

    mock_loader = DataLoader(MockDatasetWithoutIterableDataset(1, 1))
    assert not has_iterable_dataset(mock_loader)
def test_update_dataloader_with_multiprocessing_context():
    """This test verifies that replace_sampler conserves multiprocessing context."""
    dataset = RandomDataset(32, 64)
    train = DataLoader(dataset,
                       batch_size=32,
                       num_workers=2,
                       multiprocessing_context="spawn",
                       shuffle=True)
    updated = _update_dataloader(train, SequentialSampler(train.dataset))
    assert updated.multiprocessing_context == train.multiprocessing_context
def test_error_raised_with_insufficient_float_limit_train_dataloader():
    """A float ``limit_train_batches`` too small to yield one batch must raise."""
    batch_size = 16
    # 9 batches total; 10% of that rounds below a single batch.
    loader = DataLoader(RandomDataset(32, 9 * batch_size),
                        batch_size=batch_size)
    model = BoringModel()
    trainer = Trainer(limit_train_batches=0.1)
    trainer._data_connector.attach_data(model=model, train_dataloaders=loader)

    expected = "Please increase the `limit_train_batches` argument. Try at least"
    with pytest.raises(MisconfigurationException, match=expected):
        trainer.reset_train_dataloader(model)
Example #16
0
def test_combined_dataloader_for_training_with_ddp(replace_sampler_ddp: bool,
                                                   is_min_size_mode: bool,
                                                   use_combined_loader: bool):
    """When providing a CombinedLoader as the training data, it should correctly
    receive the distributed samplers."""
    mode = "min_size" if is_min_size_mode else "max_size_cycle"
    dim = 3
    n1 = 8
    n2 = 6
    # Two dataloaders of different lengths, passed either as a raw dict or
    # pre-wrapped in a CombinedLoader.
    dataloader = {
        "a": DataLoader(RandomDataset(dim, n1), batch_size=1),
        "b": DataLoader(RandomDataset(dim, n2), batch_size=1),
    }
    if use_combined_loader:
        dataloader = CombinedLoader(dataloader, mode=mode)
    model = BoringModel()
    trainer = Trainer(
        strategy="ddp",
        accelerator="auto",
        devices="auto",
        replace_sampler_ddp=replace_sampler_ddp,
        multiple_trainloader_mode="max_size_cycle"
        if use_combined_loader else mode,
    )
    trainer._data_connector.attach_data(model=model,
                                        train_dataloaders=dataloader,
                                        val_dataloaders=None,
                                        datamodule=None)
    # min_size stops at the shorter loader; max_size_cycle runs to the longer.
    expected_length_before_ddp = min(n1, n2) if is_min_size_mode else max(
        n1, n2)
    # With distributed samplers in place, each device sees a shard of batches.
    expected_length_after_ddp = (math.ceil(expected_length_before_ddp /
                                           trainer.num_devices)
                                 if replace_sampler_ddp else
                                 expected_length_before_ddp)
    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert isinstance(trainer.train_dataloader, CombinedLoader)
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp
Example #17
0
def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
    """Progress-bar totals must be consistent with ``val_check_interval`` under DDP.

    Fix: the train dataloader previously hard-coded ``RandomDataset(32, 8)``,
    ignoring ``total_train_samples`` (16) that the batch-count math below relies
    on; it now uses ``total_train_samples`` so the dataset and the expected
    counts agree.
    """
    world_size = 2
    total_train_samples = 16
    train_batch_size = 4
    total_val_samples = 2
    val_batch_size = 1
    train_data = DataLoader(RandomDataset(32, total_train_samples),
                            batch_size=train_batch_size)
    val_data = DataLoader(RandomDataset(32, total_val_samples),
                          batch_size=val_batch_size)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=1,
        val_check_interval=val_check_interval,
        accelerator="gpu",
        devices=world_size,
        strategy="ddp",
        enable_progress_bar=True,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    # Per-rank batch counts after the distributed sampler shards the datasets.
    total_train_batches = total_train_samples // (train_batch_size * world_size)
    val_check_batch = max(1, int(total_train_batches * val_check_interval))
    assert trainer.val_check_batch == val_check_batch
    val_checks_per_epoch = total_train_batches / val_check_batch
    total_val_batches = total_val_samples // (val_batch_size * world_size)
    pbar_callback = trainer.progress_bar_callback

    if trainer.is_global_zero:
        assert pbar_callback.val_progress_bar.n == total_val_batches
        assert pbar_callback.val_progress_bar.total == total_val_batches
        # The main bar counts train batches plus every validation run.
        total_val_batches = total_val_batches * val_checks_per_epoch
        assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
        assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
        assert pbar_callback.is_enabled
def test_quantization_val_test_predict(tmpdir):
    """Test the default quantization aware training not affected by validating, testing and predicting."""
    seed_everything(42)
    num_features = 16
    dm = RegressDataModule(num_features=num_features)
    qmodel = RegressionModel()

    val_test_predict_qmodel = copy.deepcopy(qmodel)
    # Fit with validation interleaved, then run every evaluation entry point.
    trainer = Trainer(
        callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        val_check_interval=1,
        num_sanity_val_steps=1,
        max_epochs=4,
    )
    trainer.fit(val_test_predict_qmodel, datamodule=dm)
    trainer.validate(model=val_test_predict_qmodel,
                     datamodule=dm,
                     verbose=False)
    trainer.test(model=val_test_predict_qmodel, datamodule=dm, verbose=False)
    trainer.predict(model=val_test_predict_qmodel,
                    dataloaders=[
                        torch.utils.data.DataLoader(
                            RandomDataset(num_features, 16))
                    ])

    expected_qmodel = copy.deepcopy(qmodel)
    # No validation in ``expected_qmodel`` fitting.
    Trainer(
        callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=0,
        max_epochs=4,
    ).fit(expected_qmodel, datamodule=dm)

    # Eval/predict phases must not have perturbed the trained weights.
    expected_state_dict = expected_qmodel.state_dict()
    for key, value in val_test_predict_qmodel.state_dict().items():
        expected_value = expected_state_dict[key]
        assert torch.allclose(value, expected_value)
Example #19
0
def test_amp_cpus(tmpdir, strategy, precision, devices):
    """Make sure combinations of AMP and strategies work if supported."""
    tutils.reset_seed()

    trainer = Trainer(default_root_dir=tmpdir,
                      accelerator="cpu",
                      devices=devices,
                      max_epochs=1,
                      strategy=strategy,
                      precision=precision)
    model = AMPTestModel()

    # Exercise every entry point under the AMP configuration.
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model, DataLoader(RandomDataset(32, 64)))

    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_prediction_writer_hook_call_intervals():
    """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined
    interval."""
    DummyPredictionWriter.write_on_batch_end = Mock()
    DummyPredictionWriter.write_on_epoch_end = Mock()

    dataloader = DataLoader(RandomDataset(32, 64))

    # "batch_and_epoch": both hooks fire (per batch, plus once per epoch).
    model = BoringModel()
    cb = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    results = trainer.predict(model, dataloaders=dataloader)
    assert len(results) == 4
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 1

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    # Same interval, but without collecting predictions in this process.
    cb = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 1

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    # "batch": only the per-batch hook fires.
    cb = DummyPredictionWriter("batch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 0

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    # "epoch": only the epoch-end hook fires.
    cb = DummyPredictionWriter("epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 0
    assert cb.write_on_epoch_end.call_count == 1
def test_batch_level_batch_indices():
    """Test that batch_indices are returned when `return_predictions=False`."""
    DummyPredictionWriter.write_on_batch_end = Mock()

    class CustomBoringModel(BoringModel):
        def on_predict_epoch_end(self, *args, **kwargs):
            # NOTE(review): epoch-wide indices appear to stay empty when
            # `return_predictions=False` -- confirm against the predict loop.
            assert self.trainer.predict_loop.epoch_batch_indices == [[]]

    writer = DummyPredictionWriter("batch")
    model = CustomBoringModel()
    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4)
    trainer = Trainer(limit_predict_batches=4, callbacks=writer)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)

    # Sequential sampling: 4 batches of 4 consecutive sample indices each.
    writer.write_on_batch_end.assert_has_calls([
        call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0),
        call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0),
        call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0),
        call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0),
    ])
def test_loader_detaching():
    """Checks that the loader has been reset after the entrypoint.

    Improvement: the same four length assertions were copy-pasted five times;
    they are now a single private helper, so each entry point is checked with
    one call and the expected lengths live in one place.
    """

    def _assert_model_loaders_untouched(m):
        # The model's own dataloaders must keep their original length after
        # running an entry point with an externally supplied loader.
        assert len(m.train_dataloader()) == 64
        assert len(m.val_dataloader()) == 64
        assert len(m.predict_dataloader()) == 64
        assert len(m.test_dataloader()) == 64

    loader = DataLoader(RandomDataset(32, 10), batch_size=1)
    model = LoaderTestModel()
    _assert_model_loaders_untouched(model)

    trainer = Trainer(fast_dev_run=1)
    trainer.fit(model, loader, loader)
    _assert_model_loaders_untouched(model)

    trainer.validate(model, loader)
    _assert_model_loaders_untouched(model)

    trainer.predict(model, loader)
    _assert_model_loaders_untouched(model)

    trainer.test(model, loader)
    _assert_model_loaders_untouched(model)
def test_prediction_writer_batch_indices(num_workers):
    """Writer hooks receive correct batch and epoch indices regardless of workers."""
    DummyPredictionWriter.write_on_batch_end = Mock()
    DummyPredictionWriter.write_on_epoch_end = Mock()

    dataloader = DataLoader(RandomDataset(32, 64),
                            batch_size=4,
                            num_workers=num_workers)
    model = BoringModel()
    writer = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=writer)
    trainer.predict(model, dataloaders=dataloader)

    # Sequential sampling: 4 batches of 4 consecutive sample indices each.
    writer.write_on_batch_end.assert_has_calls([
        call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0),
        call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0),
        call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0),
        call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0),
    ])

    # Epoch-end receives all batch indices nested per dataloader.
    writer.write_on_epoch_end.assert_has_calls([
        call(trainer, model, ANY,
             [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]),
    ])
Example #24
0
def test_model_tpu_early_stop(tmpdir):
    """Test if single TPU core training works."""
    class CustomBoringModel(BoringModel):
        def validation_step(self, *args, **kwargs):
            # Log the validation output so EarlyStopping can monitor it.
            out = super().validation_step(*args, **kwargs)
            self.log("val_loss", out["x"])
            return out

    tutils.reset_seed()
    trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss")],
                      default_root_dir=tmpdir,
                      enable_progress_bar=False,
                      max_epochs=2,
                      limit_train_batches=2,
                      limit_val_batches=2,
                      accelerator="tpu",
                      devices=8)
    model = CustomBoringModel()
    trainer.fit(model)
    test_loader = DataLoader(RandomDataset(32, 2000), batch_size=32)
    trainer.test(dataloaders=test_loader)
 def test_dataloader(self):
     """Batched test dataloader over random data."""
     dataset = RandomDataset(32, 64)
     return DataLoader(dataset, batch_size=4)
 def val_dataloader(self):
     """Validation dataloader; honors the enclosing ``shuffle`` flag."""
     dataset = RandomDataset(32, 64)
     return DataLoader(dataset, shuffle=shuffle)
def test_pre_made_batches():
    """Check that loader works with pre-made batches."""
    # batch_size=None disables automatic batching: samples pass through as-is.
    loader = DataLoader(RandomDataset(32, 10), batch_size=None)
    Trainer(fast_dev_run=1).predict(LoaderTestModel(), loader)
 def test_dataloader(self):
     """Unbatched test dataloader over random data."""
     dataset = RandomDataset(32, 64)
     return DataLoader(dataset)
 def train_dataloader(self):
     """Train dataloader wired to the custom ``collate_none_when_even`` collate."""
     dataset = RandomDataset(32, 4)
     return DataLoader(dataset, collate_fn=self.collate_none_when_even)
def test_dataloaders_with_missing_keyword_arguments():
    """``_update_dataloader`` must fail clearly when a DataLoader subclass's
    ``__init__`` cannot accept the arguments it needs to re-inject."""
    ds = RandomDataset(10, 20)

    # Case 1: __init__ takes only `dataset`; the sampler (and, for predict,
    # batching args) cannot be re-passed, so reconstruction must raise.
    class TestDataLoader(DataLoader):
        def __init__(self, dataset):
            super().__init__(dataset)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    match = escape(
        "missing arguments are ['batch_sampler', 'sampler', 'shuffle']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="fit")
    match = escape(
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']"
    )
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="predict")

    # Case 2: *args/**kwargs can forward anything, so both modes succeed.
    class TestDataLoader(DataLoader):
        def __init__(self, dataset, *args, **kwargs):
            super().__init__(dataset)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    _update_dataloader(loader, sampler, mode="fit")
    _update_dataloader(loader, sampler, mode="predict")

    # Case 3: fully variadic signature also succeeds in both modes.
    class TestDataLoader(DataLoader):
        def __init__(self, *foo, **bar):
            super().__init__(*foo, **bar)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    _update_dataloader(loader, sampler, mode="fit")
    _update_dataloader(loader, sampler, mode="predict")

    # Case 4: extra positional arg plus `shuffle` only; the sampler args are
    # still missing, so both modes must raise.
    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, *args, shuffle=False):
            self.num_feat = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, ds)
    sampler = SequentialSampler(ds)
    match = escape("missing arguments are ['batch_sampler', 'sampler']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="fit")
    match = escape(
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']"
    )
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="predict")

    # Case 5: the `num_feat` argument is stored under a different attribute
    # name (`feat_num`), so its value cannot be recovered for re-instantiation.
    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, **kwargs):
            self.feat_num = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, ds)
    sampler = SequentialSampler(ds)
    match = escape("missing attributes are ['num_feat']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="fit")
    match = escape("missing attributes are ['num_feat']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="predict")