コード例 #1
0
def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
    """Check that ``fit``, ``test`` and ``validate`` accept externally built dataloaders.

    When ``n == 1`` a single evaluation dataloader is handed over; otherwise a
    two-element list holding the same dataloader is used and the model's
    multi-dataloader hooks are installed. After each phase the trainer must
    report ``n`` dataloaders.
    """
    model = EvalModelTemplate()

    eval_loader = model.dataloader(train=False)
    if n == 1:
        dataloaders = eval_loader
    else:
        # the very same loader object fills both slots of the list
        dataloaders = [eval_loader, eval_loader]
        model.validation_step = model.validation_step__multiple_dataloaders
        model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
        model.test_step = model.test_step__multiple_dataloaders

    trainer_kwargs = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        'limit_val_batches': 0.1,
        'limit_train_batches': 0.2,
    }
    trainer = Trainer(**trainer_kwargs)

    # training data plus the (possibly multiple) validation loaders go to `fit`
    train_loader = model.dataloader(train=True)
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=dataloaders)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert len(trainer.val_dataloaders) == n

    # 'specific' is a placeholder for the best checkpoint written during `fit`
    ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path

    trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path)
    trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)

    assert len(trainer.val_dataloaders) == n
    assert len(trainer.test_dataloaders) == n
コード例 #2
0
def test_warning_with_few_workers(tmpdir):
    """Check that fitting and testing emit ``UserWarning``s for each dataloader kind.

    ``fit`` must warn with a message matching ``'train'`` and one matching
    ``'val'``; ``test`` must warn with a message matching ``'test'``.
    """
    model = EvalModelTemplate()

    # dataloaders are built up front, before the trainer exists
    train_loader = model.dataloader(train=True)
    val_loader = model.dataloader(train=False)
    test_loader = model.dataloader(train=False)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # `fit` runs once per expected message: each `pytest.warns` matches one pattern
    for expected in ('train', 'val'):
        with pytest.warns(UserWarning, match=expected):
            trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

    with pytest.warns(UserWarning, match='test'):
        trainer.test(test_dataloaders=test_loader)
コード例 #3
0
def test_warning_with_few_workers(tmpdir, ckpt_path):
    """Check that a ``UserWarning`` is emitted for dataloaders set to ``num_workers = 0``.

    Separate train, val and test dataloaders are built, each with
    ``num_workers`` forced to 0. ``fit`` must warn with messages matching
    ``'train'`` and ``'val'``, and ``test`` with a message matching ``'test'``.

    Args:
        tmpdir: directory used as the trainer's ``default_root_dir``.
        ckpt_path: checkpoint selector forwarded to ``trainer.test``; the value
            ``'specific'`` is replaced by the best checkpoint path recorded by
            the checkpoint callback after fitting.
    """
    model = EvalModelTemplate()

    trainer_options = dict(default_root_dir=tmpdir,
                           max_epochs=1,
                           limit_val_batches=0.1,
                           limit_train_batches=0.2)

    train_dl = model.dataloader(train=True)
    train_dl.num_workers = 0

    val_dl = model.dataloader(train=False)
    val_dl.num_workers = 0

    # Own name for the test loader: reusing `train_dl` here would silently
    # replace the training dataloader with an evaluation one.
    test_dl = model.dataloader(train=False)
    test_dl.num_workers = 0

    fit_options = dict(train_dataloader=train_dl, val_dataloaders=val_dl)
    trainer = Trainer(**trainer_options)

    # `fit` is run twice because each `pytest.warns` context matches one message
    with pytest.warns(UserWarning, match='train'):
        trainer.fit(model, **fit_options)

    with pytest.warns(UserWarning, match='val'):
        trainer.fit(model, **fit_options)

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    test_options = dict(test_dataloaders=test_dl, ckpt_path=ckpt_path)
    with pytest.warns(UserWarning, match='test'):
        trainer.test(**test_options)
コード例 #4
0
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
    """Check that explicitly passed val/test dataloaders work when no train dataloader is passed.

    Two fresh trainers each fit the model with only ``val_dataloaders``
    supplied; the second one additionally runs ``test`` with an explicit test
    dataloader and must end up with exactly one val and one test dataloader.
    """
    model = EvalModelTemplate()

    trainer_kwargs = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        'limit_val_batches': 0.1,
        'limit_train_batches': 0.2,
    }

    # first run: fit only
    trainer = Trainer(**trainer_kwargs)
    first_result = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
    assert first_result

    # second run: fit on a new trainer, then test
    trainer = Trainer(**trainer_kwargs)
    second_result = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
    assert second_result

    # 'specific' is a placeholder for the best checkpoint written during `fit`
    ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path
    test_loader = model.dataloader(train=False)
    trainer.test(test_dataloaders=test_loader, ckpt_path=ckpt_path)

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
コード例 #5
0
def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path):
    """Check that train and val dataloaders can go to ``fit`` and a test dataloader to ``test``.

    After both calls, ``fit`` must have returned 1 and the trainer must hold
    exactly one validation and one test dataloader.
    """
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    train_loader = model.dataloader(train=True)
    val_loader = model.dataloader(train=False)
    result = trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

    # 'specific' is a placeholder for the best checkpoint written during `fit`
    ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path
    test_loader = model.dataloader(train=False)
    trainer.test(test_dataloaders=test_loader, ckpt_path=ckpt_path)

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
コード例 #6
0
def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
    """Check that lists of val and test dataloaders are accepted by ``fit`` and ``test``.

    The model is switched to its multi-dataloader hooks, two validation
    dataloaders are passed to ``fit`` and two test dataloaders to ``test``;
    the trainer must then report two of each.
    """
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    train_loader = model.dataloader(train=True)
    val_loaders = [model.dataloader(train=False) for _ in range(2)]
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loaders)

    # 'specific' is a placeholder for the best checkpoint written during `fit`
    ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path
    test_loaders = [model.dataloader(train=False) for _ in range(2)]
    trainer.test(test_dataloaders=test_loaders, ckpt_path=ckpt_path)

    assert len(trainer.val_dataloaders) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
コード例 #7
0
def test_mixing_of_dataloader_options(tmpdir):
    """Check that explicitly passed val/test dataloaders work when no train dataloader is passed.

    Two fresh trainers each fit the model with only ``val_dataloaders``
    supplied; the second one additionally runs ``test`` with an explicit test
    dataloader and must end up with exactly one val and one test dataloader.
    """
    model = EvalModelTemplate()

    trainer_kwargs = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        'val_percent_check': 0.1,
        'train_percent_check': 0.2,
    }

    # first run: fit only
    trainer = Trainer(**trainer_kwargs)
    first_result = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
    assert first_result

    # second run: fit on a new trainer, then test
    trainer = Trainer(**trainer_kwargs)
    second_result = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
    assert second_result
    test_loader = model.dataloader(train=False)
    trainer.test(test_dataloaders=test_loader)

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
コード例 #8
0
def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """Check that lists of val and test dataloaders are accepted by ``fit`` and ``test``.

    The model is switched to its multi-dataloader validation and test steps,
    two validation dataloaders are passed to ``fit`` and two test dataloaders
    to ``test``; the trainer must then report two of each.
    """
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # every dataloader is created before either trainer call runs
    train_loader = model.dataloader(train=True)
    val_loaders = [model.dataloader(train=False) for _ in range(2)]
    test_loaders = [model.dataloader(train=False) for _ in range(2)]

    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loaders)
    trainer.test(test_dataloaders=test_loaders)

    assert len(trainer.val_dataloaders) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
コード例 #9
0
def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path):
    """Check the low-worker ``UserWarning`` when several dataloaders are used per stage.

    Train, val and test dataloaders are each built with ``num_workers`` forced
    to 0 and then duplicated: a dict of two train loaders, and lists of two
    val and two test loaders. ``fit`` must emit the warning for the train
    dataloader and for val dataloader 0, and ``test`` for test dataloader 0.

    Args:
        mock: not used in the body; presumably injected by a patch decorator
            outside this view.
        tmpdir: directory used as the trainer's ``default_root_dir``.
        ckpt_path: checkpoint selector forwarded to ``trainer.test``; the value
            ``'specific'`` is replaced by the best checkpoint path recorded by
            the checkpoint callback after fitting.
    """
    model = EvalModelTemplate()
    model.training_step = model.training_step__multiple_dataloaders
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    train_dl = model.dataloader(train=True)
    train_dl.num_workers = 0

    val_dl = model.dataloader(train=False)
    val_dl.num_workers = 0

    # Own name for the test loader: reusing `train_dl` here would silently
    # replace the training dataloader with an evaluation one.
    test_dl = model.dataloader(train=False)
    test_dl.num_workers = 0

    train_multi_dl = {'a': train_dl, 'b': train_dl}
    val_multi_dl = [val_dl, val_dl]
    test_multi_dl = [test_dl, test_dl]

    fit_options = dict(train_dataloader=train_multi_dl,
                       val_dataloaders=val_multi_dl)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # `fit` is run twice because each `pytest.warns` context matches one message
    with pytest.warns(
            UserWarning,
            match=
            'The dataloader, train dataloader, does not have many workers which may be a bottleneck.'
    ):
        trainer.fit(model, **fit_options)

    with pytest.warns(
            UserWarning,
            match=
            'The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.'
    ):
        trainer.fit(model, **fit_options)

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    test_options = dict(test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
    with pytest.warns(
            UserWarning,
            match=
            'The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.'
    ):
        trainer.test(**test_options)
コード例 #10
0
def test_warning_on_wrong_test_settigs(tmpdir):
    """Exercise the trainer's reaction to incomplete test-loop configurations.

    Cases covered, in order:
        * ``test_step`` removed, then ``fit``: ``MisconfigurationException``
        * ``test_epoch_end`` removed, then ``test``: ``RuntimeWarning``
        * ``test_dataloader`` returning ``None``, then ``test``:
          ``MisconfigurationException``
        * ``test_dataloader`` returning ``None`` and ``test_step`` removed,
          with a test dataloader passed to ``test``: ``MisconfigurationException``
        * ``test_dataloader`` returning ``None`` and ``test_epoch_end`` removed,
          with a test dataloader passed to ``test``: ``RuntimeWarning``
    """
    tutils.reset_seed()
    hparams = tutils.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    def build_model(**overrides):
        # fresh template model with the named attributes replaced, in keyword order
        fresh = EvalModelTemplate(hparams)
        for attr_name, replacement in overrides.items():
            setattr(fresh, attr_name, replacement)
        return fresh

    # a test dataloader without a test step must be rejected
    with pytest.raises(MisconfigurationException):
        trainer.fit(build_model(test_step=None))

    # test dataloader and test step present: a missing test_epoch_end only warns
    with pytest.warns(RuntimeWarning):
        trainer.test(build_model(test_epoch_end=None))

    # test step present but no test dataloader available anywhere
    with pytest.raises(MisconfigurationException):
        trainer.test(build_model(test_dataloader=lambda: None))

    # test dataloader passed to `test`, but the model has no test step
    with pytest.raises(MisconfigurationException):
        model = build_model(test_dataloader=lambda: None, test_step=None)
        trainer.test(model, test_dataloaders=model.dataloader(train=False))

    # test dataloader passed to `test` and test step present, no test_epoch_end
    with pytest.warns(RuntimeWarning):
        model = build_model(test_dataloader=lambda: None, test_epoch_end=None)
        trainer.test(model, test_dataloaders=model.dataloader(train=False))
コード例 #11
0
def test_train_val_dataloaders_passed_to_fit(tmpdir):
    """Check that ``fit`` accepts a train and a val dataloader together.

    ``fit`` must return 1 and the trainer must hold exactly one validation
    dataloader afterwards.
    """
    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    train_loader = model.dataloader(train=True)
    val_loader = model.dataloader(train=False)
    result = trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
コード例 #12
0
def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
    """Check the low-worker ``UserWarning`` for one stage with several dataloaders.

    Two evaluation dataloaders are built with ``num_workers`` forced to 0. One
    is duplicated into the validation list; the other is used both for the
    train dict and for the test list. When ``stage`` is ``'test'`` the trainer
    runs ``test``, otherwise ``fit``; in either case a warning naming the
    stage's dataloader is expected.
    """
    model = EvalModelTemplate()
    model.training_step = model.training_step__multiple_dataloaders
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    val_loader = model.dataloader(train=False)
    val_loader.num_workers = 0

    # a single evaluation loader backs both the train dict and the test list
    shared_loader = model.dataloader(train=False)
    shared_loader.num_workers = 0

    train_multi_dl = dict(a=shared_loader, b=shared_loader)
    val_multi_dl = [val_loader] * 2
    test_multi_dl = [shared_loader] * 2

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # only val/test messages carry a dataloader index
    index_suffix = "" if stage == "train" else " 0"
    expected = f'The dataloader, {stage} dataloader{index_suffix}, does not have many workers'

    with pytest.warns(UserWarning, match=expected):
        if stage != 'test':
            trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)
        else:
            if ckpt_path == 'specific':
                ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
コード例 #13
0
def test_train_dataloader_passed_to_fit(tmpdir):
    """Check that ``fit`` accepts a train dataloader on its own and returns 1."""
    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # no validation dataloader is passed explicitly here
    train_loader = model.dataloader(train=True)
    result = trainer.fit(model, train_dataloader=train_loader)

    assert result == 1
コード例 #14
0
def test_all_dataloaders_passed_to_fit(tmpdir):
    """Check that train and val dataloaders can go to ``fit`` and a test dataloader to ``test``.

    After both calls, ``fit`` must have returned 1 and the trainer must hold
    exactly one validation and one test dataloader.
    """
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # every dataloader is created before either trainer call runs
    train_loader = model.dataloader(train=True)
    val_loader = model.dataloader(train=False)
    test_loader = model.dataloader(train=False)

    result = trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
    trainer.test(test_dataloaders=test_loader)

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
コード例 #15
0
def test_train_dataloader_passed_to_fit(tmpdir):
    """Check that ``fit`` accepts a train dataloader on its own and returns 1."""
    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)

    trainer_kwargs = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        'val_percent_check': 0.1,
        'train_percent_check': 0.2,
    }
    trainer = Trainer(**trainer_kwargs)

    # no validation dataloader is passed explicitly here
    train_loader = model.dataloader(train=True)
    result = trainer.fit(model, train_dataloader=train_loader)

    assert result == 1
コード例 #16
0
def test_error_on_dataloader_passed_to_fit(tmpdir):
    """Check that ``fit`` raises ``MisconfigurationException`` when a train
    dataloader is passed while ``auto_scale_batch_size='power'`` is set."""
    model = EvalModelTemplate()

    trainer_kwargs = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        'val_percent_check': 0.1,
        'train_percent_check': 0.2,
        'auto_scale_batch_size': 'power',
    }
    trainer = Trainer(**trainer_kwargs)

    # the dataloader is built outside the context so only `fit` can trigger the error
    train_loader = model.dataloader(train=True)

    with pytest.raises(MisconfigurationException):
        trainer.fit(model, train_dataloader=train_loader)
コード例 #17
0
def test_train_dataloader_passed_to_fit(tmpdir):
    """Check that ``fit`` accepts a train dataloader on its own and finishes.

    Success is judged by the trainer ending in ``TrainerState.FINISHED``.
    """
    model = EvalModelTemplate()

    trainer_kwargs = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        'limit_val_batches': 0.1,
        'limit_train_batches': 0.2,
    }
    trainer = Trainer(**trainer_kwargs)

    # no validation dataloader is passed explicitly here
    train_loader = model.dataloader(train=True)
    trainer.fit(model, train_dataloader=train_loader)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
コード例 #18
0
def test_test_loop_config(tmpdir):
    """Check that ``trainer.test`` warns when the test loop or the test data is missing.

    Two configurations are tried, each expected to emit a ``UserWarning``:
    ``test_dataloader`` set to ``None`` with nothing passed to ``test``, and
    ``test_step`` set to ``None`` with a test dataloader passed to ``test``.
    """
    hparams = EvalModelTemplate.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # test step is defined, but there is no test data to run it on
    with pytest.warns(UserWarning):
        model_without_data = EvalModelTemplate(**hparams)
        model_without_data.test_dataloader = None
        trainer.test(model_without_data)

    # test data is supplied, but the model has no test step
    with pytest.warns(UserWarning):
        model_without_step = EvalModelTemplate(**hparams)
        model_without_step.test_step = None
        eval_loader = model_without_step.dataloader(train=False)
        trainer.test(model_without_step, test_dataloaders=eval_loader)