Esempio n. 1
0
def test_multiple_test_dataloader(tmpdir):
    """Verify that a model with multiple test dataloaders is wired up correctly.

    Fits a model composed from ``LightningTestMultipleDataloadersMixin`` and
    ``LightningTestModelBase``. It then checks that the trainer exposes exactly
    two test dataloaders, runs a prediction check on each of them, and finally
    runs ``trainer.test()``.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestMultipleDataloadersMixin,
                           LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # keep the run short: a single epoch with reduced train/val percent checks
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # fit model (the return value is not checked by this test)
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    # verify there are 2 test loaders
    assert len(trainer.get_test_dataloaders()) == 2, \
        'Multiple test_dataloaders not initiated properly'

    # make sure predictions are good for each test set
    for dataloader in trainer.get_test_dataloaders():
        tutils.run_prediction(dataloader, trainer.model)

    # run the test method
    trainer.test()
Esempio n. 2
0
def test_mixing_of_dataloader_options(tmpdir):
    """Verify that dataloaders can be passed to ``fit``.

    Runs two fits of the same model on fresh trainers:

    1. only a ``val_dataloader`` is passed to ``fit`` (smoke test, no checks);
    2. both a ``val_dataloader`` and a ``test_dataloader`` are passed, after
       which exactly one val and one test dataloader must be registered.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # keep the runs short: a single epoch with reduced train/val percent checks
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model, passing only a val dataloader
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False))
    trainer.fit(model, **fit_options)

    # fit model again on a fresh trainer, passing both val and test dataloaders
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False),
                       test_dataloader=model._dataloader(train=False))
    trainer.fit(model, **fit_options)
    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'
Esempio n. 3
0
def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """Check that lists of val and test dataloaders handed to ``fit`` are all kept.

    A train dataloader plus two val and two test dataloaders are passed
    directly to ``trainer.fit``. The trainer must then report two val and two
    test dataloaders.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBaseWithoutDataloader):
        pass

    hparams = tutils.get_hparams()

    # one short epoch with reduced train/val percent checks
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)

    # all loaders come from the model's own ``_dataloader`` helper
    train_loader = model._dataloader(train=True)
    val_loaders = [model._dataloader(train=False) for _ in range(2)]
    test_loaders = [model._dataloader(train=False) for _ in range(2)]

    trainer.fit(model,
                train_dataloader=train_loader,
                val_dataloader=val_loaders,
                test_dataloader=test_loaders)

    assert len(trainer.get_val_dataloaders()) == 2, (
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}')
    assert len(trainer.get_test_dataloaders()) == 2, (
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}')
Esempio n. 4
0
def test_amp_gpu_ddp_slurm_managed():
    """Make sure DDP + AMP work when SLURM is (simulated to be) managing tasks.

    Fits a model on GPU 0 with the ``ddp`` backend and AMP enabled. It then:

    - checks ``resolve_root_node_address`` on several node-list strings;
    - reloads the model and runs predictions on every test dataloader;
    - exercises HPC save/load and ``freeze``/``unfreeze``.

    Returns early (and so passes trivially) when ``can_run_gpu_test()`` is
    false. The ``SLURM_LOCALID`` environment variable is restored, and the
    save dir is cleared, even if an assertion fails. This keeps the simulated
    SLURM state and leftover artifacts from leaking into other tests.
    """
    if not can_run_gpu_test():
        return

    reset_seed()

    # simulate setting slurm flags; remember the previous value to restore it
    set_random_master_port()
    previous_local_id = os.environ.get('SLURM_LOCALID')
    os.environ['SLURM_LOCALID'] = str(0)

    try:
        hparams = get_hparams()
        model = LightningTestModel(hparams)

        trainer_options = dict(show_progress_bar=True,
                               max_nb_epochs=1,
                               gpus=[0],
                               distributed_backend='ddp',
                               use_amp=True)

        save_dir = init_save_dir()
        try:
            # exp file to get meta
            logger = get_test_tube_logger(False)

            # exp file to get weights
            checkpoint = init_checkpoint_callback(logger)

            # add these to the trainer options
            trainer_options['checkpoint_callback'] = checkpoint
            trainer_options['logger'] = logger

            # fit model, pretending SLURM is managing the processes
            trainer = Trainer(**trainer_options)
            trainer.is_slurm_managing_tasks = True
            result = trainer.fit(model)

            # fit must report successful completion
            assert result == 1, 'amp + ddp model failed to complete'

            # test root model address
            assert trainer.resolve_root_node_address('abc') == 'abc'
            assert trainer.resolve_root_node_address('abc[23]') == 'abc23'
            assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
            assert trainer.resolve_root_node_address(
                'abc[23-24, 45-40, 40]') == 'abc23'

            # test model loading (via the load_model helper)
            pretrained_model = load_model(logger.experiment, save_dir)

            # test model preds on every test dataloader
            for dataloader in trainer.get_test_dataloaders():
                run_prediction(dataloader, pretrained_model)

            if trainer.use_ddp:
                # on hpc this would work fine... but need to hack it for the purpose of the test
                trainer.model = pretrained_model
                trainer.optimizers, trainer.lr_schedulers = \
                    pretrained_model.configure_optimizers()

            # test HPC loading / saving
            trainer.hpc_save(save_dir, logger)
            trainer.hpc_load(save_dir, on_gpu=True)

            # test freeze on gpu
            model.freeze()
            model.unfreeze()
        finally:
            # clean up artifacts even when an assertion above fails
            clear_save_dir()
    finally:
        # undo the simulated SLURM flag so it cannot affect later tests
        if previous_local_id is None:
            os.environ.pop('SLURM_LOCALID', None)
        else:
            os.environ['SLURM_LOCALID'] = previous_local_id