コード例 #1
0
def test_multiple_val_dataloader(tmpdir):
    """Check that a model exposing two validation dataloaders is wired up correctly.

    The template model's validation hooks are replaced by their
    multi-dataloader variants, the model is fit for one epoch, and every
    validation dataloader left on the trainer is sanity-checked with
    ``tutils.run_prediction``.
    """
    model = EvalModelTemplate()
    # swap in the multi-dataloader versions of the validation hooks
    model.val_dataloader = model.val_dataloader__multiple
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=1.0,
    )
    trainer = Trainer(**trainer_kwargs)
    fit_status = trainer.fit(model)

    # fit() reporting 1 is treated as "training completed"
    assert fit_status == 1

    val_loaders = trainer.val_dataloaders
    assert len(val_loaders) == 2, \
        'Multiple val_dataloaders not initiated properly'

    # each validation set should produce good predictions
    for loader in val_loaders:
        tutils.run_prediction(loader, trainer.model)
コード例 #2
0
def test_multiple_test_dataloader(tmpdir, ckpt_path):
    """Verify multiple test_dataloader.

    The model is fit briefly and then tested with ``ckpt_path`` forwarded to
    ``trainer.test``. The special value ``'specific'`` is replaced by the
    ``best_model_path`` recorded by the trainer's checkpoint callback. After
    testing, the trainer must hold exactly two test dataloaders. Each one is
    sanity-checked with ``tutils.run_prediction``, and then the test loop is
    run once more.
    """

    # The multi-dataloader hooks are defined on a subclass rather than being
    # monkeypatched onto one instance, so they belong to the model class itself
    # (presumably so they are still present if ``trainer.test`` restores the
    # model from a checkpoint -- TODO confirm). They call the inherited
    # ``*__multiple*`` variants on ``self``. Delegating to a separate template
    # instance would run ``forward`` with that instance's untrained weights
    # (and on its device) instead of the model actually under test.
    class MultipleTestDataloaderModel(EvalModelTemplate):
        def test_dataloader(self):
            return self.test_dataloader__multiple()

        def test_step(self, batch, batch_idx, *args, **kwargs):
            return self.test_step__multiple_dataloaders(
                batch, batch_idx, *args, **kwargs)

    model = MultipleTestDataloaderModel()

    # fit model
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      limit_val_batches=0.1,
                      limit_train_batches=0.2)
    trainer.fit(model)
    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.test(ckpt_path=ckpt_path)

    # verify there are 2 test loaders
    assert len(trainer.test_dataloaders) == 2, \
        'Multiple test_dataloaders not initiated properly'

    # make sure predictions are good for each test set
    for dataloader in trainer.test_dataloaders:
        tutils.run_prediction(dataloader, trainer.model)

    # run the test method again
    trainer.test(ckpt_path=ckpt_path)
コード例 #3
0
def test_multiple_test_dataloader(tmpdir):
    """Check that ``trainer.test()`` handles a model exposing two test dataloaders.

    The template's test hooks are replaced with their multi-dataloader
    variants. After a short fit and one ``trainer.test()`` call, the trainer
    must hold exactly two test dataloaders. Each is sanity-checked with
    ``tutils.run_prediction``, and then the test loop is run a second time.
    """
    model = EvalModelTemplate()
    # swap in the multi-dataloader versions of the test hooks
    model.test_dataloader = model.test_dataloader__multiple
    model.test_step = model.test_step__multiple_dataloaders

    options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    trainer = Trainer(**options)
    trainer.fit(model)
    trainer.test()

    test_loaders = trainer.test_dataloaders
    assert len(test_loaders) == 2, \
        'Multiple test_dataloaders not initiated properly'

    # each test set should produce good predictions
    for loader in test_loaders:
        tutils.run_prediction(loader, trainer.model)

    # the test loop must also be runnable again
    trainer.test()
コード例 #4
0
def test_multiple_test_dataloader(tmpdir):
    """Check that a mixin-composed model with two test dataloaders can be tested.

    The model combines the light train-dataloader, multiple-test-dataloader
    and empty-test-step mixins over ``TestModelBase``. After a short fit and
    ``trainer.test()``, exactly two test dataloaders must be present. Each is
    sanity-checked with ``tutils.run_prediction`` before testing again.
    """
    # mixin order defines the MRO and is kept as-is
    class CurrentTestModel(
            LightTrainDataloader,
            LightTestMultipleDataloadersMixin,
            LightEmptyTestStep,
            TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    trainer = Trainer(**options)
    trainer.fit(model)
    trainer.test()

    test_loaders = trainer.test_dataloaders
    assert len(test_loaders) == 2, \
        'Multiple test_dataloaders not initiated properly'

    # each test set should produce good predictions
    for loader in test_loaders:
        tutils.run_prediction(loader, trainer.model)

    # the test loop must also be runnable again
    trainer.test()
コード例 #5
0
def test_multiple_val_dataloader(tmpdir):
    """Check that a mixin-composed model with two validation dataloaders trains.

    The model combines the light train-dataloader and multiple-validation-
    dataloader mixins over ``TestModelBase``. It is fit for one epoch, the
    trainer must hold exactly two validation dataloaders, and each is
    sanity-checked with ``tutils.run_prediction``.
    """
    # mixin order defines the MRO and is kept as-is
    class CurrentTestModel(
            LightTrainDataloader,
            LightValidationMultipleDataloadersMixin,
            TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=1.0,
    )
    trainer = Trainer(**options)
    fit_status = trainer.fit(model)

    # fit() reporting 1 is treated as "training completed"
    assert fit_status == 1

    val_loaders = trainer.val_dataloaders
    assert len(val_loaders) == 2, \
        'Multiple val_dataloaders not initiated properly'

    # each validation set should produce good predictions
    for loader in val_loaders:
        tutils.run_prediction(loader, trainer.model)
コード例 #6
0
    def assert_good_acc():
        """Check the resumed trainer's epoch and its model's predictions.

        Reads ``new_trainer``, ``real_global_epoch`` and ``trainer`` from the
        enclosing scope.
        """
        resumed_epoch = new_trainer.current_epoch
        assert resumed_epoch == real_global_epoch and resumed_epoch > 0

        # If model and state were restored correctly, predictions should be
        # good even though the newly loaded model has not been trained further.
        restored_model = new_trainer.model
        restored_model.eval()

        # NOTE(review): the dataloader is taken from the original `trainer`,
        # not `new_trainer` -- confirm this is intentional.
        tutils.run_prediction(trainer.train_dataloader, restored_model, dp=True)
コード例 #7
0
def test_running_test_pretrained_model_distrib(tmpdir, backend):
    """Fit on GPUs 0 and 1 with ``backend``, reload the checkpoint and ``test()`` it.

    After fitting, the checkpointed model is loaded via ``tutils.load_model``.
    It is tested with a fresh trainer built from the same options, its test
    accuracy is checked with ``tutils.assert_ok_model_acc``, and predictions
    are run over every test dataloader of the original model.
    """
    tutils.reset_seed()
    tutils.set_random_master_port()

    model = LightningTestModel(tutils.get_default_hparams())

    # the logger supplies experiment metadata; the checkpoint callback stores weights
    logger = tutils.get_default_logger(tmpdir)
    ckpt_callback = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.4,
        val_percent_check=0.2,
        checkpoint_callback=ckpt_callback,
        logger=logger,
        gpus=[0, 1],
        distributed_backend=backend,
    )

    trainer = Trainer(**trainer_options)
    fit_status = trainer.fit(model)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    assert fit_status == 1, 'training failed to complete'
    pretrained_model = tutils.load_model(
        logger,
        trainer.checkpoint_callback.dirpath,
        module_class=LightningTestModel,
    )

    # evaluate the reloaded model with a fresh trainer
    tester = Trainer(**trainer_options)
    tester.test(pretrained_model)
    tutils.assert_ok_model_acc(tester)

    # test_dataloader() may return a single loader or a list of them
    raw_loaders = model.test_dataloader()
    test_loaders = raw_loaders if isinstance(raw_loaders, list) else [raw_loaders]

    for loader in test_loaders:
        tutils.run_prediction(loader, pretrained_model)