示例#1
0
def test_train_val_dataloaders_passed_to_fit(tmpdir):
    """Verify that train & val dataloaders can be passed directly to ``fit``.

    The model is built on ``LightningTestModelBaseWithoutDataloader``
    (presumably providing no dataloader hooks of its own — per its name), and
    its train/val dataloaders are handed to ``trainer.fit`` explicitly.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBaseWithoutDataloader):
        pass

    hparams = tutils.get_hparams()

    # short run: one epoch on a fraction of the train/val data
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # train, val passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloader=model._dataloader(train=False))
    result = trainer.fit(model, **fit_options)

    # same completion check as the other fit-based tests in this file
    assert result == 1, 'training failed to complete'
    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
def test_inf_val_dataloader(tmpdir):
    """Test inf val data loader (e.g. IterableDataset).

    Fitting with a fractional ``val_percent_check`` must raise
    ``MisconfigurationException`` for this model; fitting without that
    option must train to completion.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightInfValDataloader, LightningTestModel):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # Build the trainer outside pytest.raises so that only the fit call
    # can satisfy the expected MisconfigurationException.
    trainer = Trainer(default_save_path=tmpdir,
                      max_epochs=1,
                      val_percent_check=0.5)
    with pytest.raises(MisconfigurationException):
        trainer.fit(model)

    # without val_percent_check the same model should fit normally
    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)
    result = trainer.fit(model)

    # verify training completed
    assert result == 1
示例#3
0
def test_running_test_without_val(tmpdir):
    """Verify `test()` works on a model with no `val_loader`.

    Trains for one epoch with a test-tube logger and a checkpoint callback
    attached to it, then runs ``trainer.test()`` and checks the model's
    accuracy via ``tutils.assert_ok_model_acc``.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestMixin, LightningTestModelBase):
        pass

    model = CurrentTestModel(tutils.get_hparams())

    # test-tube logger, plus a checkpoint callback that writes through it
    logger = tutils.get_test_tube_logger(tmpdir, False)
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer = Trainer(
        show_progress_bar=False,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        test_percent_check=0.2,
        checkpoint_callback=checkpoint,
        logger=logger,
        early_stop_callback=False,
    )

    # keep fit() outside the assert so it still runs under `python -O`
    fit_result = trainer.fit(model)
    assert fit_result == 1, 'training failed to complete'

    trainer.test()

    # test we have good test accuracy
    tutils.assert_ok_model_acc(trainer)
def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works.

    Runs one short training epoch with a ``TrainsLogger`` attached and
    checks that ``fit`` reports completion. The logger is always finalized,
    even if training raises.
    """
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # NOTE(review): bypass mode is set at class level and never reset here;
    # presumably it keeps the logger from talking to the hosts below — confirm.
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(api_host='http://integration.trains.allegro.ai:8008',
                                 files_host='http://integration.trains.allegro.ai:8081',
                                 web_host='http://integration.trains.allegro.ai:8080')
    logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger
    )
    trainer = Trainer(**trainer_options)
    try:
        result = trainer.fit(model)
    finally:
        # finalize unconditionally so a failing fit does not leave the logger open
        logger.finalize()

    assert result == 1, "Training failed"
def test_dataloader_config_errors(tmpdir):
    """Check that invalid data-fraction / validation-interval options make ``fit`` raise.

    Each case builds a fresh one-epoch Trainer with one bad option and
    expects ``trainer.fit`` to raise ``ValueError``:

    * ``train_percent_check`` < 0
    * ``train_percent_check`` > 1
    * int ``val_check_interval`` > number of batches
    * float ``val_check_interval`` > 1
    """
    tutils.reset_seed()

    class CurrentTestModel(
            LightTrainDataloader,
            TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_hparams())

    invalid_settings = (
        {'train_percent_check': -0.1},   # percent check < 0
        {'train_percent_check': 1.1},    # percent check > 1
        {'val_check_interval': 10000},   # int val_check_interval > num batches
        {'val_check_interval': 1.1},     # float val_check_interval > 1
    )

    for bad_setting in invalid_settings:
        trainer = Trainer(default_save_path=tmpdir, max_epochs=1, **bad_setting)
        with pytest.raises(ValueError):
            trainer.fit(model)