# --- Code example #1 ---
def test_testpass_overrides(tmpdir):
    """Check how ``Trainer.test`` reacts when individual test hooks are reset
    to the bare ``LightningModule`` base-class implementations."""
    # TODO: check duplicated tests against trainer_checks
    hparams = EvalModelTemplate.get_default_hparams()

    def build_model(**overrides):
        # Fresh template model with the given hooks replaced.
        m = EvalModelTemplate(**hparams)
        for hook_name, hook in overrides.items():
            setattr(m, hook_name, hook)
        return m

    # Misconfiguration: `test_dataloader` is not implemented.
    with pytest.raises(MisconfigurationException,
                       match='.*not implement `test_dataloader`.*'):
        Trainer().test(build_model(test_dataloader=LightningModule.test_dataloader))

    # Misconfiguration: `test_step` is not implemented.
    with pytest.raises(MisconfigurationException):
        Trainer().test(build_model(test_step=LightningModule.test_step))

    # No exception when only `test_step_end` falls back to the base class.
    Trainer().test(build_model(test_step_end=LightningModule.test_step_end))

    # No exception for the fully implemented template model.
    Trainer().test(build_model())
# --- Code example #2 ---
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches,
                                                limit_val_batches,
                                                limit_test_batches):
    """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    def scaled_lengths(loaders, fraction):
        # Expected batch count per loader after applying a fractional limit.
        return [int(len(loader) * fraction) for loader in loaders]

    # train, multiple val and multiple test passed with percent_check
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    trainer.fit(model)

    # Trainer must have truncated each loader to the requested fraction.
    assert trainer.num_training_batches == int(
        len(trainer.train_dataloader) * limit_train_batches)
    assert trainer.num_val_batches == scaled_lengths(
        trainer.val_dataloaders, limit_val_batches)

    trainer.test(ckpt_path=None)
    assert trainer.num_test_batches == scaled_lengths(
        trainer.test_dataloaders, limit_test_batches)
# --- Code example #3 ---
def test_full_train_loop_with_results_obj_dp(tmpdir):
    """Run a full DP training loop using result objects and verify that the
    expected metric keys were logged."""
    os.environ['PL_DEV_DEBUG'] = '1'

    batches = 10
    epochs = 3

    model = EvalModelTemplate()
    # Only the training path is exercised; validation/test hooks and their
    # dataloaders are disabled for this run.
    model.validation_step = None
    model.test_step = None
    model.training_step = model.training_step_full_loop_result_obj_dp
    model.training_step_end = model.training_step_end_full_loop_result_obj_dp
    model.training_epoch_end = model.training_epoch_end_full_loop_result_obj_dp
    model.val_dataloader = None
    model.test_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        distributed_backend='dp',
        gpus=[0, 1],
        max_epochs=epochs,
        early_stop_callback=True,
        row_log_interval=2,
        limit_train_batches=batches,
        weights_summary=None,
    )

    trainer.fit(model)

    # Collect every metric key the dev debugger observed during the run.
    seen_keys = {
        key
        for metric in trainer.dev_debugger.logged_metrics
        for key in metric.keys()
    }

    for expected_key in ('train_step_metric',
                         'train_step_end_metric',
                         'epoch_train_epoch_end_metric'):
        assert expected_key in seen_keys
# --- Code example #4 ---
def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
    """ Test that error is raised if dataloader with only a few workers is used """

    model = EvalModelTemplate()
    model.training_step = model.training_step__multiple_dataloaders
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    # NOTE(review): both loaders are built with train=False — presumably only
    # num_workers matters for this warning, but confirm train_dl should not
    # use train=True.
    val_dl = model.dataloader(train=False)
    val_dl.num_workers = 0

    train_dl = model.dataloader(train=False)
    train_dl.num_workers = 0

    train_multi_dl = {'a': train_dl, 'b': train_dl}
    val_multi_dl = [val_dl, val_dl]
    test_multi_dl = [train_dl, train_dl]

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # The train loader is reported without an index; eval loaders carry " 0".
    loader_suffix = '' if stage == 'train' else ' 0'
    expected_warning = f'The dataloader, {stage} dataloader{loader_suffix}, does not have many workers'

    with pytest.warns(UserWarning, match=expected_warning):
        if stage == 'test':
            if ckpt_path == 'specific':
                ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
        else:
            trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)
# --- Code example #5 ---
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches,
                                            limit_val_batches,
                                            limit_test_batches):
    """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""

    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    # train, multiple val and multiple test passed with percent_check
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    trainer.fit(model)

    # -------------------------------------------
    # MAKE SURE THE TRAINER SET THE CORRECT VALUES
    # -------------------------------------------
    assert trainer.num_training_batches == limit_train_batches
    assert trainer.num_val_batches == (
        [limit_val_batches] * len(trainer.val_dataloaders))
    trainer.test(ckpt_path=None)

    # when the limit is greater than the number of test batches it should be the num in loaders
    test_dataloader_lengths = [len(loader) for loader in model.test_dataloader()]
    if limit_test_batches > 1e10:
        assert trainer.num_test_batches == test_dataloader_lengths
    else:
        assert trainer.num_test_batches == (
            [limit_test_batches] * len(trainer.test_dataloaders))

    # -------------------------------------------
    # make sure we actually saw the expected num of batches
    # -------------------------------------------
    num_val_dataloaders = len(model.val_dataloader())
    num_test_dataloaders = len(model.test_dataloader())
    if limit_train_batches > 0:

        # every val dataloader ran exactly `limit_val_batches` batches
        seen_val = trainer.dev_debugger.num_seen_val_check_batches
        assert len(seen_val) == num_val_dataloaders
        assert all(count == limit_val_batches for count in seen_val.values())

        # every test dataloader ran the limited (or full) number of batches
        seen_test = trainer.dev_debugger.num_seen_test_check_batches
        assert len(seen_test) == num_test_dataloaders
        for dataloader_idx, count in seen_test.items():
            if limit_test_batches > 1e10:
                expected = test_dataloader_lengths[dataloader_idx]
            else:
                expected = limit_test_batches
            assert count == expected
# --- Code example #6 ---
def test_wrong_test_settigs(tmpdir):
    """ Test the following cases related to test configuration of model:
        * error if `test_dataloader()` is overridden but `test_step()` is not
        * if both `test_dataloader()` and `test_step()` is overridden,
            throw warning if `test_epoch_end()` is not defined
        * error if `test_step()` is overridden but `test_dataloader()` is not
    """
    hparams = EvalModelTemplate.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # NOTE: hparams is expanded as keyword arguments, consistent with
    # `EvalModelTemplate(**hparams)` used by the other tests in this file;
    # passing the dict positionally was a bug.

    # ----------------
    # if have test_dataloader should  have test_step
    # ----------------
    with pytest.raises(MisconfigurationException):
        model = EvalModelTemplate(**hparams)
        model.test_step = None
        trainer.fit(model)

    # ----------------
    # if have test_dataloader  and  test_step recommend test_epoch_end
    # ----------------
    with pytest.warns(RuntimeWarning):
        model = EvalModelTemplate(**hparams)
        model.test_epoch_end = None
        trainer.test(model)

    # ----------------
    # if have test_step and NO test_dataloader passed in tell user to pass test_dataloader
    # ----------------
    with pytest.raises(MisconfigurationException):
        model = EvalModelTemplate(**hparams)
        model.test_dataloader = LightningModule.test_dataloader
        trainer.test(model)

    # ----------------
    # if have test_dataloader and NO test_step tell user to implement  test_step
    # ----------------
    with pytest.raises(MisconfigurationException):
        model = EvalModelTemplate(**hparams)
        model.test_dataloader = LightningModule.test_dataloader
        model.test_step = None
        trainer.test(model, test_dataloaders=model.dataloader(train=False))

    # ----------------
    # if have test_dataloader and test_step but no test_epoch_end warn user
    # ----------------
    with pytest.warns(RuntimeWarning):
        model = EvalModelTemplate(**hparams)
        model.test_dataloader = LightningModule.test_dataloader
        model.test_epoch_end = None
        trainer.test(model, test_dataloaders=model.dataloader(train=False))

    # ----------------
    # if we are just testing, no need for train_dataloader, training_step, val_dataloader, and validation_step
    # ----------------
    model = EvalModelTemplate(**hparams)
    model.test_dataloader = LightningModule.test_dataloader
    # Use the real LightningModule hook names (`training_step`,
    # `validation_step`) as the rest of this file does; the previous
    # `train_step`/`val_step` assignments were no-op attributes that never
    # disabled anything.
    model.train_dataloader = None
    model.training_step = None
    model.val_dataloader = None
    model.validation_step = None
    trainer.test(model, test_dataloaders=model.dataloader(train=False))