Example #1
def test_running_test_without_val(tmpdir):
    """Verify `test()` works on a model with no `val_loader`."""
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_testtube_logger(tmpdir, False)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        test_percent_check=0.2,
        checkpoint_callback=checkpoint,
        logger=logger,
        early_stop_callback=False
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, 'training failed to complete'

    trainer.test()

    # test we have good test accuracy
    tutils.assert_ok_model_acc(trainer)
Example #2
def test_optimizer_with_scheduling(tmpdir):
    """ Verify that learning rate scheduling is working """
    tutils.reset_seed()

    class CurrentTestModel(
            LightTestOptimizerWithSchedulingMixin,
            LightTrainDataloader,
            TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model)

    init_lr = hparams.learning_rate
    adjusted_lr = [pg['lr'] for pg in trainer.optimizers[0].param_groups]

    assert len(trainer.lr_schedulers) == 1, \
        'lr scheduler not initialized properly, it has %i elements instead of 1' % len(trainer.lr_schedulers)

    assert all(a == adjusted_lr[0] for a in adjusted_lr), \
        'Lr not equally adjusted for all param groups'
    adjusted_lr = adjusted_lr[0]

    assert init_lr * 0.1 == adjusted_lr, \
        'Lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr)
Example #3
def test_mixing_of_dataloader_options(tmpdir):
    """Verify that dataloaders can be passed to fit"""
    tutils.reset_seed()

    class CurrentTestModel(
            LightTrainDataloader,
            LightValStepFitSingleDataloaderMixin,
            LightTestFitSingleTestDataloadersMixin,
            TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloaders=model._dataloader(train=False))
    results = trainer.fit(model, **fit_options)

    # fit model
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloaders=model._dataloader(train=False),
                       test_dataloaders=model._dataloader(train=False))
    _ = trainer.fit(model, **fit_options)
    trainer.test()

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initialized properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initialized properly, got {trainer.test_dataloaders}'
Example #4
def test_all_dataloaders_passed_to_fit(tmpdir):
    """Verify train, val & test dataloader can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(
            LightTrainDataloader,
            LightValStepFitSingleDataloaderMixin,
            LightTestFitSingleTestDataloadersMixin,
            LightEmptyTestStep,
            TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(default_root_dir=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # train, val and test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=model._dataloader(train=False))
    test_options = dict(test_dataloaders=model._dataloader(train=False))

    result = trainer.fit(model, **fit_options)

    trainer.test(**test_options)

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initialized properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initialized properly, got {trainer.test_dataloaders}'
Example #5
def test_no_val_module(tmpdir):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    model = CurrentTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    trainer = Trainer(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )
    # fit model
    result = trainer.fit(model)
    # training complete
    assert result == 1, 'training failed to complete'

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # assert ckpt has hparams
    ckpt = torch.load(new_weights_path)
    assert 'hparams' in ckpt.keys(), 'hparams missing from checkpoints'

    # won't load without hparams in the ckpt
    model_2 = LightningTestModel.load_from_checkpoint(
        checkpoint_path=new_weights_path,
    )
    model_2.eval()
Example #6
def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Make sure DDP + AMP work."""
    tutils.reset_seed()

    # simulate setting slurm flags
    tutils.set_random_master_port()
    os.environ['SLURM_LOCALID'] = str(0)

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # fit model
    trainer = Trainer(
        max_epochs=1,
        gpus=[0],
        distributed_backend='ddp',
        precision=16,
        checkpoint_callback=checkpoint,
        logger=logger,
    )
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'amp + ddp model failed to complete'

    # test root model address
    assert trainer.resolve_root_node_address('abc') == 'abc'
    assert trainer.resolve_root_node_address('abc[23]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'
Example #7
def test_running_test_pretrained_model(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_testtube_logger(tmpdir, False)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        show_progress_bar=False,
        max_epochs=4,
        train_percent_check=0.4,
        val_percent_check=0.2,
        checkpoint_callback=checkpoint,
        logger=logger
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'training failed to complete'
    pretrained_model = tutils.load_model(
        logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
    )

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer)
Example #8
def test_warning_with_few_workers(tmpdir):
    """ Test that error is raised if dataloader with only a few workers is used """
    tutils.reset_seed()

    class CurrentTestModel(
            LightTrainDataloader,
            LightValStepFitSingleDataloaderMixin,
            LightTestFitSingleTestDataloadersMixin,
            LightEmptyTestStep,
            TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(default_root_dir=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=model._dataloader(train=False))
    test_options = dict(test_dataloaders=model._dataloader(train=False))

    trainer = Trainer(**trainer_options)

    # fit model
    with pytest.warns(UserWarning, match='train'):
        trainer.fit(model, **fit_options)

    with pytest.warns(UserWarning, match='val'):
        trainer.fit(model, **fit_options)

    with pytest.warns(UserWarning, match='test'):
        trainer.test(**test_options)
Example #9
def test_multiple_test_dataloader(tmpdir):
    """Verify multiple test_dataloader."""
    tutils.reset_seed()

    class CurrentTestModel(
            LightTrainDataloader,
            LightTestMultipleDataloadersMixin,
            LightEmptyTestStep,
            TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(default_root_dir=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)
    trainer.test()

    # verify there are 2 test loaders
    assert len(trainer.test_dataloaders) == 2, \
        'Multiple test_dataloaders not initialized properly'

    # make sure predictions are good for each test set
    for dataloader in trainer.test_dataloaders:
        tutils.run_prediction(dataloader, trainer.model)

    # run the test method
    trainer.test()
Example #10
def test_loading_meta_tags(tmpdir):
    """ test for backward compatibility to meta_tags.csv """
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()

    # save tags
    logger = tutils.get_default_logger(tmpdir)
    logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0))
    logger.log_hyperparams(hparams)
    logger.save()

    # load hparams
    path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE)
    hparams = load_hparams_from_yaml(hparams_path)

    # save as legacy meta_tags.csv
    tags_path = os.path.join(path_expt_dir, 'meta_tags.csv')
    save_hparams_to_tags_csv(tags_path, hparams)

    tags = load_hparams_from_tags_csv(tags_path)

    assert hparams == tags
Example #11
def test_no_val_end_module(tmpdir):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, LightValidationStepMixin, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
    model_2 = LightningTestModel.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        tags_csv=tags_path
    )
    model_2.eval()
Example #12
def test_lr_logger_multi_lrs(tmpdir):
    """ Test that learning rates are extracted and logged for multi lr schedulers """
    tutils.reset_seed()

    class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin,
                           LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    lr_logger = LearningRateLogger()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      val_percent_check=0.1,
                      train_percent_check=0.5,
                      callbacks=[lr_logger])
    results = trainer.fit(model)

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
        'Names of learning rates not set correctly'
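The scheduling mixin used above is not shown on this page. As a rough, hypothetical sketch (names and values are illustrative, not the mixin's actual code), a configure_optimizers of roughly this shape would produce the two 'lr-Adam' / 'lr-Adam-1' entries the assertions expect:

import torch
from torch import nn


class TwoOptimizerSketch(nn.Module):
    """Hypothetical stand-in for LightTestOptimizersWithMixedSchedulingMixin."""

    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)
        self.learning_rate = learning_rate

    def configure_optimizers(self):
        # two Adam optimizers -> LearningRateLogger names them 'lr-Adam' and 'lr-Adam-1'
        opt1 = torch.optim.Adam(self.layer.parameters(), lr=self.learning_rate)
        opt2 = torch.optim.Adam(self.layer.parameters(), lr=self.learning_rate)
        sched1 = torch.optim.lr_scheduler.StepLR(opt1, step_size=1, gamma=0.1)
        sched2 = torch.optim.lr_scheduler.StepLR(opt2, step_size=1, gamma=0.1)
        # "mixed" scheduling: one scheduler stepped per batch, one per epoch
        return [opt1, opt2], [
            {'scheduler': sched1, 'interval': 'step'},
            {'scheduler': sched2, 'interval': 'epoch'},
        ]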
Example #13
def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log",
                          task_name="pytorch lightning test")

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      train_percent_check=0.05,
                      logger=logger)
    result = trainer.fit(model)

    print('result finished')
    logger.finalize()
    assert result == 1, "Training failed"
Example #14
def test_train_dataloaders_passed_to_fit(tmpdir):
    """Verify that train dataloader can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # only train passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True))
    result = trainer.fit(model, **fit_options)

    assert result == 1
Example #15
def test_multiple_val_dataloader(tmpdir):
    """Verify multiple val_dataloader."""
    tutils.reset_seed()

    class CurrentTestModel(
            LightTrainDataloader,
            LightValidationMultipleDataloadersMixin,
            TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=1.0,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # verify training completed
    assert result == 1

    # verify there are 2 val loaders
    assert len(trainer.val_dataloaders) == 2, \
        'Multiple val_dataloaders not initialized properly'

    # make sure predictions are good for each val set
    for dataloader in trainer.val_dataloaders:
        tutils.run_prediction(dataloader, trainer.model)
Example #16
def test_neptune_leave_open_experiment_after_fit(tmpdir):
    """Verify that neptune experiment was closed after training"""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    def _run_training(logger):
        logger._experiment = MagicMock()

        trainer_options = dict(default_root_dir=tmpdir,
                               max_epochs=1,
                               train_percent_check=0.05,
                               logger=logger)
        trainer = Trainer(**trainer_options)
        trainer.fit(model)
        return logger

    logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
    assert logger_close_after_fit._experiment.stop.call_count == 1

    logger_open_after_fit = _run_training(
        NeptuneLogger(offline_mode=True, close_after_fit=False))
    assert logger_open_after_fit._experiment.stop.call_count == 0
Example #17
def test_trainer_reset_correctly(tmpdir):
    ''' Check that all trainer parameters are reset correctly after lr_find() '''
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1
    )

    changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
                          'progress_bar_refresh_rate',
                          'accumulate_grad_batches',
                          'checkpoint_callback']
    attributes_before = {}
    for ca in changed_attributes:
        attributes_before[ca] = getattr(trainer, ca)

    _ = trainer.lr_find(model, num_training=5)

    attributes_after = {}
    for ca in changed_attributes:
        attributes_after[ca] = getattr(trainer, ca)

    for key in changed_attributes:
        assert attributes_before[key] == attributes_after[key], \
            f'Attribute {key} was not reset correctly after learning rate finder'
Example #18
def test_trainer_arg_str(tmpdir):
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    hparams.__dict__['my_fancy_lr'] = 1.0  # update with non-standard field
    model = CurrentTestModel(hparams)
    before_lr = hparams.my_fancy_lr
    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        auto_lr_find='my_fancy_lr'
    )

    trainer.fit(model)
    after_lr = model.hparams.my_fancy_lr
    assert before_lr != after_lr, \
        'Learning rate was not altered after running learning rate finder'
Example #19
def test_wandb_logger(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""
    tutils.reset_seed()

    logger = WandbLogger(anonymous=True, offline=True)

    logger.log_metrics({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0})

    wandb.init().log.reset_mock()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})

    logger.log_hyperparams({'test': None})
    wandb.init().config.update.assert_called_once_with({'test': None})

    logger.watch('model', 'log', 10)
    wandb.watch.assert_called_once_with('model', log='log', log_freq=10)

    logger.finalize('fail')
    wandb.join.assert_called_once_with(1)

    wandb.join.reset_mock()
    logger.finalize('success')
    wandb.join.assert_called_once_with(0)

    wandb.join.reset_mock()
    wandb.join.side_effect = TypeError
    with pytest.raises(TypeError):
        logger.finalize('any')

    wandb.join.assert_called()

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #20
def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """Verify that multiple val & test dataloaders can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightningTestModel,
        LightValStepFitMultipleDataloadersMixin,
        LightTestFitMultipleTestDataloadersMixin,
    ):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # train, multiple val and multiple test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=[model._dataloader(train=False),
                                        model._dataloader(train=False)],
                       test_dataloaders=[model._dataloader(train=False),
                                         model._dataloader(train=False)])
    results = trainer.fit(model, **fit_options)
    trainer.test()

    assert len(trainer.val_dataloaders) == 2, \
        f'Multiple `val_dataloaders` not initialized properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 2, \
        f'Multiple `test_dataloaders` not initialized properly, got {trainer.test_dataloaders}'
Example #21
def test_tbptt_cpu_model(tmpdir):
    """Test truncated back propagation through time works."""
    tutils.reset_seed()

    truncated_bptt_steps = 2
    sequence_size = 30
    batch_size = 30

    x_seq = torch.rand(batch_size, sequence_size, 1)
    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()

    class MockSeq2SeqDataset(torch.utils.data.Dataset):
        def __getitem__(self, i):
            return x_seq, y_seq_list

        def __len__(self):
            return 1

    class BpttTestModel(LightTrainDataloader, TestModelBase):
        def __init__(self, hparams):
            super().__init__(hparams)
            self.test_hidden = None

        def training_step(self, batch, batch_idx, hiddens):
            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
            self.test_hidden = torch.rand(1)

            x_tensor, y_list = batch
            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"

            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"

            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
            loss_val = torch.nn.functional.mse_loss(
                pred, y_tensor.view(batch_size, truncated_bptt_steps))
            return {
                'loss': loss_val,
                'hiddens': self.test_hidden,
            }

        def train_dataloader(self):
            return torch.utils.data.DataLoader(
                dataset=MockSeq2SeqDataset(),
                batch_size=batch_size,
                shuffle=False,
                sampler=None,
            )

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        truncated_bptt_steps=truncated_bptt_steps,
        val_percent_check=0,
        weights_summary=None,
        early_stop_callback=False
    )

    hparams = tutils.get_default_hparams()
    hparams.batch_size = batch_size
    hparams.in_features = truncated_bptt_steps
    hparams.hidden_dim = truncated_bptt_steps
    hparams.out_features = truncated_bptt_steps

    model = BpttTestModel(hparams)

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, 'training failed to complete'
Example #22
def test_dataloader_config_errors(tmpdir):
    tutils.reset_seed()

    class CurrentTestModel(
            LightTrainDataloader,
            TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # percent check < 0

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=-0.1,
    )

    # fit model
    trainer = Trainer(**trainer_options)

    with pytest.raises(ValueError):
        trainer.fit(model)

    # percent check > 1

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=1.1,
    )

    # fit model
    trainer = Trainer(**trainer_options)

    with pytest.raises(ValueError):
        trainer.fit(model)

    # int val_check_interval > num batches

    # logger file to get meta
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_check_interval=10000)

    # fit model
    trainer = Trainer(**trainer_options)

    with pytest.raises(ValueError):
        trainer.fit(model)

    # float val_check_interval > 1

    # logger file to get meta
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_check_interval=1.1)

    # fit model
    trainer = Trainer(**trainer_options)

    with pytest.raises(ValueError):
        trainer.fit(model)
Example #23
def test_cpu_slurm_save_load(tmpdir):
    """Verify model save/load/checkpoint on CPU."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_default_testtube_logger(tmpdir, False)
    version = logger.version

    trainer_options = dict(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'training failed to complete'

    # predict with trained model before saving
    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        for batch in dataloader:
            break

    x, y = batch
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)

    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.hpc_save(tmpdir, logger)
    assert os.path.exists(saved_filepath)

    # new logger file to get meta
    logger = tutils.get_default_testtube_logger(tmpdir, False, version=version)

    trainer_options = dict(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)
Example #24
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""

    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        max_epochs=1,
        gpus=2,
        distributed_backend='dp',
    )

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['checkpoint_callback'] = checkpoint

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert result == 1, 'dp model failed to complete'

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
    trainer_options['train_percent_check'] = 0.5
    trainer_options['val_percent_check'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        dp_model = new_trainer.model
        dp_model.eval()

        dataloader = trainer.train_dataloader
        tutils.run_prediction(dataloader, dp_model, dp=True)

    # new model
    model = LightningTestModel(hparams)
    model.on_train_start = assert_good_acc

    # fit new model which should load hpc weights
    new_trainer.fit(model)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
Example #25
def test_optimizer_return_options():
    tutils.reset_seed()

    trainer = Trainer()
    model, hparams = tutils.get_default_model()

    # optimizers and schedulers shared by the cases below
    opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
    opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
    scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10)
    scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10)

    # single optimizer
    model.configure_optimizers = lambda: opt_a
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0

    # opt tuple
    model.configure_optimizers = lambda: (opt_a, opt_b)
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b
    assert len(lr_sched) == 0 and len(freq) == 0

    # opt list
    model.configure_optimizers = lambda: [opt_a, opt_b]
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b
    assert len(lr_sched) == 0 and len(freq) == 0

    # opt tuple of 2 lists
    model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a,
                               interval='epoch',
                               frequency=1,
                               reduce_on_plateau=False,
                               monitor='val_loss')

    # opt single dictionary
    model.configure_optimizers = lambda: {
        "optimizer": opt_a,
        "lr_scheduler": scheduler_a
    }
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a,
                               interval='epoch',
                               frequency=1,
                               reduce_on_plateau=False,
                               monitor='val_loss')

    # opt multiple dictionaries with frequencies
    model.configure_optimizers = lambda: (
        {
            "optimizer": opt_a,
            "lr_scheduler": scheduler_a,
            "frequency": 1
        },
        {
            "optimizer": opt_b,
            "lr_scheduler": scheduler_b,
            "frequency": 5
        },
    )
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a,
                               interval='epoch',
                               frequency=1,
                               reduce_on_plateau=False,
                               monitor='val_loss')
    assert freq == [1, 5]
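For reference, a minimal sketch (assumed usage, not code from this page) of the single-dictionary format exercised above, written as it would normally appear inside a model's own configure_optimizers rather than monkey-patched via a lambda:

import torch


def configure_optimizers(self):
    # single-dictionary format: init_optimizers() unpacks it into one optimizer
    # and one scheduler entry with the defaults asserted above
    # (interval='epoch', frequency=1, monitor='val_loss')
    optimizer = torch.optim.Adam(self.parameters(), lr=0.002)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    return {'optimizer': optimizer, 'lr_scheduler': scheduler}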
Example #26
def test_resume_from_checkpoint_epoch_restored(tmpdir):
    """Verify resuming from checkpoint runs the right number of epochs"""
    import types

    tutils.reset_seed()

    hparams = tutils.get_default_hparams()

    def _new_model():
        # Create a model that tracks epochs and batches seen
        model = LightningTestModel(hparams)
        model.num_epochs_seen = 0
        model.num_batches_seen = 0

        def increment_epoch(self):
            self.num_epochs_seen += 1

        def increment_batch(self, _):
            self.num_batches_seen += 1

        # Bind the increment_epoch function on_epoch_end so that the
        # model keeps track of the number of epochs it has seen.
        model.on_epoch_end = types.MethodType(increment_epoch, model)
        model.on_batch_start = types.MethodType(increment_batch, model)
        return model

    model = _new_model()

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.65,
        val_percent_check=1,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        logger=False,
        default_save_path=tmpdir,
        early_stop_callback=False,
        val_check_interval=1.,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    training_batches = trainer.num_training_batches

    assert model.num_epochs_seen == 2
    assert model.num_batches_seen == training_batches * 2

    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
    checkpoints = sorted(
        glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))

    for check in checkpoints:
        next_model = _new_model()
        state = torch.load(check)

        # Resume training
        trainer_options['max_epochs'] = 2
        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
        new_trainer.fit(next_model)
        assert state['global_step'] + next_model.num_batches_seen == \
            training_batches * trainer_options['max_epochs']
Example #27
def test_gradient_accumulation_scheduling(tmpdir):
    """
    Test grad accumulation by the freq of optimizer updates
    """
    tutils.reset_seed()

    # test incorrect configs; each invalid value needs its own context manager,
    # otherwise only the first raising statement in a block would be exercised
    with pytest.raises(IndexError):
        Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6})
    with pytest.raises(IndexError):
        Trainer(accumulate_grad_batches={-2: 3})

    with pytest.raises(TypeError):
        Trainer(accumulate_grad_batches={})
    with pytest.raises(TypeError):
        Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
    with pytest.raises(TypeError):
        Trainer(accumulate_grad_batches={1: 2, 3.: 4})
    with pytest.raises(TypeError):
        Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

    # test optimizer call freq matches scheduler
    def _optimizer_step(self,
                        epoch,
                        batch_idx,
                        optimizer,
                        optimizer_idx,
                        second_order_closure=None):
        # only test the first 12 batches in epoch
        if batch_idx < 12:
            if epoch == 0:
                # reset counter when starting epoch
                if batch_idx == 0:
                    self.prev_called_batch_idx = 0

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 1

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 1

            elif 1 <= epoch <= 2:
                # reset counter when starting epoch
                if batch_idx == 1:
                    self.prev_called_batch_idx = 1

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 2

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 2

            else:
                if batch_idx == 3:
                    self.prev_called_batch_idx = 3

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 4

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 3

        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)
    schedule = {1: 2, 3: 4}  # accumulate 2 batches from epoch 1, 4 from epoch 3

    trainer = Trainer(accumulate_grad_batches=schedule,
                      train_percent_check=0.1,
                      val_percent_check=0.1,
                      max_epochs=2,
                      default_save_path=tmpdir)

    # for the test
    trainer.optimizer_step = _optimizer_step
    model.prev_called_batch_idx = 0

    trainer.fit(model)
Example #28
def test_numpy_metric_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()
    world_size = 2
    mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)