Esempio n. 1
0
def test_multiple_val_dataloader(tmpdir):
    """Fit a model built with the multiple-val-dataloaders mixin.

    After fitting, checks that ``fit`` returned 1, that the trainer exposes
    exactly two validation dataloaders, and runs the shared prediction
    check against each of them.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningValidationMultipleDataloadersMixin,
                           LightningTestModelBase):
        pass

    model = CurrentTestModel(tutils.get_hparams())

    # build the trainer and run the fit
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=1.0,
    )
    fit_status = trainer.fit(model)

    # a return value of 1 is what this suite treats as "training completed"
    assert fit_status == 1

    # the trainer must report two validation dataloaders
    num_val_loaders = len(trainer.get_val_dataloaders())
    assert num_val_loaders == 2, (
        'Multiple val_dataloaders not initiated properly'
    )

    # make sure predictions are good for each validation set
    for val_loader in trainer.get_val_dataloaders():
        tutils.run_prediction(val_loader, trainer.model)
Esempio n. 2
0
def test_mixing_of_dataloader_options(tmpdir):
    """Verify that val/test dataloaders can be passed to ``fit``.

    Two fits are run on the same model, each with a fresh ``Trainer``:
    first passing only ``val_dataloader``, then passing both
    ``val_dataloader`` and ``test_dataloader``. Each fit must return 1, and
    after the second fit the trainer must expose exactly one validation and
    one test dataloader.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # trainer settings shared by both fits
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # first fit: only a validation dataloader is passed explicitly
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False))
    result = trainer.fit(model, **fit_options)
    # previously the return value was stored but never checked
    assert result == 1, \
        'training did not complete when passing `val_dataloader` to fit'

    # second fit: both validation and test dataloaders are passed explicitly
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False),
                       test_dataloader=model._dataloader(train=False))
    result = trainer.fit(model, **fit_options)
    assert result == 1, \
        'training did not complete when passing val and test dataloaders to fit'

    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'
Esempio n. 3
0
def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """Verify that multiple val & test dataloaders can be passed to fit.

    A train dataloader plus lists of two validation and two test
    dataloaders are passed to ``fit``. The fit must return 1, and the
    trainer must afterwards expose two validation and two test dataloaders.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBaseWithoutDataloader):
        pass

    hparams = tutils.get_hparams()

    # trainer settings for the single fit below
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # train, multiple val and multiple test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloader=[
                           model._dataloader(train=False),
                           model._dataloader(train=False),
                       ],
                       test_dataloader=[
                           model._dataloader(train=False),
                           model._dataloader(train=False),
                       ])
    result = trainer.fit(model, **fit_options)

    # previously the return value was stored but never checked
    assert result == 1, \
        'training did not complete when passing multiple dataloaders to fit'

    assert len(trainer.get_val_dataloaders()) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'