예제 #1
0
def test_num_dataloader_batches(tmpdir):
    """Check how many batches the trainer allocates for a given batch limit.

    Two scenarios are fitted one after the other, each with a fresh
    ``BoringModel`` and ``Trainer``:

    * the limit (100) exceeds the dataloader length (64), so the dataloader
      length is what the trainer should use;
    * the limit (7) is below the dataloader length, so the limit is what the
      trainer should use.
    """
    dataloader_length = 64

    # (limit handed to the Trainer, batch count the Trainer is expected to report)
    scenarios = (
        (100, dataloader_length),  # fewer batches available than the limit
        (7, 7),  # more batches available than the limit
    )

    for batch_limit, expected_batches in scenarios:
        model = BoringModel()
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            limit_train_batches=batch_limit,
            limit_val_batches=batch_limit,
        )
        trainer.fit(model)

        # The model's own dataloaders keep their full length in both scenarios.
        assert len(model.train_dataloader()) == dataloader_length
        assert len(model.val_dataloader()) == dataloader_length

        # Validation batch counts are stored as a list; only entry 0 is checked.
        assert isinstance(trainer.num_val_batches, list)
        assert trainer.num_val_batches[0] == expected_batches
        assert trainer.num_training_batches == expected_batches
예제 #2
0
def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
    """Test that a ``UserWarning`` is emitted for a dataloader with few workers.

    Both the train and the val dataloader get ``num_workers = 0``. Depending on
    ``stage`` the trainer is driven through ``test`` (for ``'test'``) or
    ``fit`` (any other value), and the warning text must name that stage.
    When ``ckpt_path`` is ``'specific'`` it is replaced by the checkpoint
    callback's ``best_model_path`` before ``trainer.test`` is called.
    """
    model = BoringModel()

    # Zero workers on both loaders so the "not many workers" warning applies.
    train_dl = model.train_dataloader()
    train_dl.num_workers = 0
    val_dl = model.val_dataloader()
    val_dl.num_workers = 0

    trainer = Trainer(
        limit_train_batches=0.2,
        limit_val_batches=0.1,
        max_epochs=1,
        default_root_dir=tmpdir,
    )

    # Non-train stages carry a dataloader index (" 0") in the warning text.
    expected_warning = (
        f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers'
    )

    with pytest.warns(UserWarning, match=expected_warning):
        if stage != 'test':
            trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
        else:
            if ckpt_path == 'specific':
                ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path)
예제 #3
0
def test_dataloader_source_request_from_module():
    """Check that a ``_DataLoaderSource`` can fetch a dataloader from a module.

    The module attribute ``foo`` is a mock returning the module's train
    dataloader. The mock must stay uncalled until ``dataloader()`` is
    requested from the source, and must then have been called exactly once.
    """
    model = BoringModel()
    model.trainer = Trainer()
    train_loader = model.train_dataloader()
    model.foo = Mock(return_value=train_loader)

    loader_source = _DataLoaderSource(model, "foo")
    assert loader_source.is_module()
    # Building the source and querying `is_module` must not touch the hook.
    model.foo.assert_not_called()

    fetched = loader_source.dataloader()
    assert isinstance(fetched, DataLoader)
    model.foo.assert_called_once()
예제 #4
0
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """Test the warning shown when an ``IterableDataset`` defines ``__len__``.

    First half: a dataloader over an iterable dataset *with* ``__len__`` is
    passed to ``validate``, ``fit``, ``test`` and ``predict``; each call must
    emit the ``UserWarning``.

    Second half: the same four entry points are run with an iterable dataset
    *without* ``__len__``. Here the test only checks that the calls complete;
    it makes no assertion about warnings.
    """
    model = BoringModel()
    original_dataset = model.train_dataloader().dataset

    class IterableWithoutLen(IterableDataset):
        """Iterable view over the model's train dataset; no ``__len__``."""

        def __iter__(self):
            return iter(original_dataset)

    class IterableWithLen(IterableWithoutLen):
        """Same iterable view, additionally reporting the dataset length."""

        def __len__(self):
            return len(original_dataset)

    len_warning = 'Your `IterableDataset` has `__len__` defined.'

    # --- iterable dataset that defines __len__ ---
    sized_loader = DataLoader(IterableWithLen(), batch_size=16)
    assert has_len(sized_loader)
    assert has_iterable_dataset(sized_loader)

    sized_trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    # Each entry point gets its own `pytest.warns` context, run in this order.
    warning_entry_points = (
        lambda: sized_trainer.validate(model, val_dataloaders=[sized_loader]),
        lambda: sized_trainer.fit(model, train_dataloader=sized_loader, val_dataloaders=[sized_loader]),
        lambda: sized_trainer.test(model, test_dataloaders=[sized_loader]),
        lambda: sized_trainer.predict(model, dataloaders=[sized_loader]),
    )
    for run_entry_point in warning_entry_points:
        with pytest.warns(UserWarning, match=len_warning):
            run_entry_point()

    # --- iterable dataset that does not define __len__ ---
    unsized_loader = DataLoader(IterableWithoutLen(), batch_size=16)
    assert not has_len(unsized_loader)
    assert has_iterable_dataset(unsized_loader)

    unsized_trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    unsized_trainer.validate(model, val_dataloaders=unsized_loader)
    unsized_trainer.fit(model, train_dataloader=unsized_loader, val_dataloaders=[unsized_loader])
    unsized_trainer.test(model, test_dataloaders=unsized_loader)
    unsized_trainer.predict(model, dataloaders=unsized_loader)