# Example No. 1
def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Make sure DDP + AMP work when SLURM is managing the tasks.

    Also checks how ``Trainer.resolve_root_node_address`` resolves SLURM-style
    node-list strings. ``SLURM_LOCALID`` is only set for the duration of this
    test: its previous value (or its absence) is restored afterwards so the
    simulated SLURM environment cannot leak into later tests.
    """
    if not tutils.can_run_gpu_test():
        return

    tutils.reset_seed()

    # simulate setting slurm flags; remember the old value so it can be restored
    tutils.set_random_master_port()
    previous_local_id = os.environ.get('SLURM_LOCALID')
    os.environ['SLURM_LOCALID'] = str(0)

    try:
        hparams = tutils.get_default_hparams()
        model = LightningTestModel(hparams)

        trainer_options = dict(show_progress_bar=True,
                               max_epochs=1,
                               gpus=[0],
                               distributed_backend='ddp',
                               precision=16)

        # exp file to get meta
        logger = tutils.get_default_testtube_logger(tmpdir, False)

        # exp file to get weights
        checkpoint = tutils.init_checkpoint_callback(logger)

        # add these to the trainer options
        trainer_options['checkpoint_callback'] = checkpoint
        trainer_options['logger'] = logger

        # fit model; the trainer is flagged as SLURM-managed before fitting
        trainer = Trainer(**trainer_options)
        trainer.is_slurm_managing_tasks = True
        result = trainer.fit(model)

        # training must have completed
        assert result == 1, 'amp + ddp model failed to complete'

        # test root model address: bracketed node lists resolve to the first node
        assert trainer.resolve_root_node_address('abc') == 'abc'
        assert trainer.resolve_root_node_address('abc[23]') == 'abc23'
        assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
        assert trainer.resolve_root_node_address(
            'abc[23-24, 45-40, 40]') == 'abc23'
    finally:
        # restore the environment exactly as it was before the test
        if previous_local_id is None:
            os.environ.pop('SLURM_LOCALID', None)
        else:
            os.environ['SLURM_LOCALID'] = previous_local_id
def test_call_to_trainer_method(tmpdir):
    """Calling ``Trainer.lr_find`` explicitly should yield a different learning rate."""
    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)
    original_lr = hparams.learning_rate

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    # run the finder by hand and apply its suggestion before fitting
    finder_result = trainer.lr_find(model, mode='linear')
    suggested_lr = finder_result.suggestion()
    model.hparams.learning_rate = suggested_lr
    trainer.fit(model)

    assert original_lr != suggested_lr, \
        'Learning rate was not altered after running learning rate finder'
# Example No. 3
def test_custom_logger(tmpdir):
    """A user-defined logger receives hparams and metrics and is finalized."""
    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)
    logger = CustomLogger()

    trainer = Trainer(
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger,
        default_root_dir=tmpdir,
    )
    assert trainer.fit(model) == 1, "Training failed"

    # the logger must have seen the hparams, some metrics, and a clean shutdown
    assert logger.hparams_logged == hparams
    assert logger.metrics_logged != {}
    assert logger.finalized_status == "success"
# Example No. 4
def test_configure_optimizer_from_dict(tmpdir):
    """``configure_optimizers`` may return a dict that holds only an ``optimizer`` entry."""

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        def configure_optimizers(self):
            sgd = torch.optim.SGD(params=self.parameters(), lr=1e-03)
            return {'optimizer': sgd}

    model = CurrentTestModel(tutils.get_default_hparams())

    # a single epoch is enough to prove the dict form is accepted
    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)
    assert trainer.fit(model) == 1
# Example No. 5
def test_amp_gpu_ddp(tmpdir):
    """DDP training combined with 16-bit precision (AMP) on two GPUs."""
    if not tutils.can_run_gpu_test():
        return

    tutils.reset_seed()
    tutils.set_random_master_port()

    model = LightningTestModel(tutils.get_default_hparams())

    options = {
        'default_save_path': tmpdir,
        'show_progress_bar': True,
        'max_epochs': 1,
        'gpus': 2,
        'distributed_backend': 'ddp',
        'precision': 16,
    }
    tutils.run_model_test(options, model)
# Example No. 6
def test_testtube_logger(tmpdir):
    """A short training run with a TestTube logger attached completes."""
    tutils.reset_seed()
    model = LightningTestModel(tutils.get_default_hparams())

    logger = tutils.get_default_testtube_logger(tmpdir, False)
    assert logger.name == 'lightning_logs'

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger,
    )
    outcome = trainer.fit(model)

    assert outcome == 1, 'Training failed'
def test_dataloader_config_errors(tmpdir, dataloader_options):
    """``fit`` raises ``ValueError`` for the given dataloader-related trainer options.

    ``dataloader_options`` is supplied externally (presumably via parametrize)
    and is expected to contain an invalid configuration.
    """

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, **dataloader_options)

    with pytest.raises(ValueError):
        trainer.fit(model)
# Example No. 8
def test_trainer_arg_bool(tmpdir):
    """With ``auto_lr_find=True``, ``fit`` should replace ``hparams.learning_rate``."""

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)
    lr_before_fit = hparams.learning_rate

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1, auto_lr_find=True)
    trainer.fit(model)

    lr_after_fit = model.hparams.learning_rate
    assert lr_before_fit != lr_after_fit, \
        'Learning rate was not altered after running learning rate finder'
# Example No. 9
def test_early_stopping_cpu_model(tmpdir):
    """Train on CPU with early stopping, gradient clipping and grad-norm tracking,
    then make sure the model can be frozen and unfrozen."""
    early_stop = EarlyStopping(monitor='val_loss', min_delta=0.1)
    trainer_options = {
        'default_root_dir': tmpdir,
        'early_stop_callback': early_stop,
        'gradient_clip_val': 1.0,
        'overfit_pct': 0.20,
        'track_grad_norm': 2,
        'train_percent_check': 0.1,
        'val_percent_check': 0.1,
    }

    model = EvalModelTemplate(tutils.get_default_hparams())
    tutils.run_model_test(trainer_options, model, on_gpu=False)

    # freezing and unfreezing on CPU must not raise
    model.freeze()
    model.unfreeze()
# Example No. 10
def test_none_optimizer(tmpdir):
    """Training completes for a model mixing in ``LightTestNoneOptimizerMixin``
    (presumably one whose optimizer configuration yields no optimizer — confirm)."""
    tutils.reset_seed()

    class CurrentTestModel(LightTestNoneOptimizerMixin,
                           LightTrainDataloader,
                           TestModelBase):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    # training must run to completion
    assert trainer.fit(model) == 1
# Example No. 11
def test_error_on_more_than_1_optimizer(tmpdir):
    """``Trainer.lr_find`` must raise ``MisconfigurationException`` when the
    model configures more than one optimizer."""
    tutils.reset_seed()

    class CurrentTestModel(LightTestMultipleOptimizersWithSchedulingMixin,
                           LightTrainDataloader,
                           TestModelBase):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())
    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    with pytest.raises(MisconfigurationException):
        trainer.lr_find(model)
# Example No. 12
def test_testtube_pickle(tmpdir):
    """A trainer holding a TestTube logger survives a pickle round trip."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    logger = tutils.get_default_testtube_logger(tmpdir, False)
    logger.log_hyperparams(hparams)
    logger.save()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger,
    )

    # serialize, restore, and make sure the restored logger is still usable
    restored = pickle.loads(pickle.dumps(trainer))
    restored.logger.log_metrics({'acc': 1.0})
# Example No. 13
def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    """DDP training works when the dataloaders are handed directly to ``fit()``."""
    tutils.set_random_master_port()

    trainer_options = {
        'default_root_dir': tmpdir,
        'progress_bar_refresh_rate': 0,
        'max_epochs': 1,
        'train_percent_check': 0.4,
        'val_percent_check': 0.2,
        'gpus': [0, 1],
        'distributed_backend': 'ddp',
    }

    model = EvalModelTemplate(tutils.get_default_hparams())
    fit_options = {
        'train_dataloader': model.train_dataloader(),
        'val_dataloaders': model.val_dataloader(),
    }

    trainer = Trainer(**trainer_options)
    outcome = trainer.fit(model, **fit_options)
    assert outcome == 1, "DDP doesn't work with dataloaders passed to fit()."
# Example No. 14
def test_neptune_leave_open_experiment_after_fit(tmpdir):
    """The Neptune experiment is stopped after ``fit`` by default, but left open
    when the logger is created with ``close_after_fit=False``."""
    model = LightningTestModel(tutils.get_default_hparams())

    def _fit_with(neptune_logger):
        # swap in a mock experiment so calls to ``stop`` can be counted
        neptune_logger._experiment = MagicMock()
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            train_percent_check=0.05,
            logger=neptune_logger,
        )
        trainer.fit(model)
        return neptune_logger

    closing_logger = _fit_with(NeptuneLogger(offline_mode=True))
    assert closing_logger._experiment.stop.call_count == 1

    open_logger = _fit_with(NeptuneLogger(offline_mode=True, close_after_fit=False))
    assert open_logger._experiment.stop.call_count == 0
def test_trainer_arg(tmpdir, scale_arg):
    """Passing ``auto_scale_batch_size=scale_arg`` to the trainer changes the
    model's batch size during ``fit``."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)
    initial_batch_size = hparams.batch_size

    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        auto_scale_batch_size=scale_arg,
    )
    trainer.fit(model)

    final_batch_size = model.hparams.batch_size
    assert initial_batch_size != final_batch_size, \
        'Batch size was not altered after running auto scaling of batch size'
def test_warning_on_wrong_validation_settings(tmpdir):
    """ Test the following cases related to validation configuration of model:
        * error if `val_dataloader()` is overridden but `validation_step()` is not
        * if both `val_dataloader()` and `validation_step()` are overridden,
            throw a warning if no validation epoch-end hook is defined
        * error if `validation_step()` is overridden but `val_dataloader()` is not

    Each model is instantiated *before* entering ``pytest.raises`` /
    ``pytest.warns`` so that only ``trainer.fit`` can satisfy the expectation;
    an error or warning from model construction would otherwise mask a bug.
    """
    tutils.reset_seed()
    hparams = tutils.get_default_hparams()

    # a single trainer instance is reused for all three configurations
    trainer_options = dict(default_root_dir=tmpdir, max_epochs=1)
    trainer = Trainer(**trainer_options)

    class CurrentTestModel(LightTrainDataloader,
                           LightValidationDataloader,
                           TestModelBase):
        pass

    # check val_dataloader -> val_step
    model = CurrentTestModel(hparams)
    with pytest.raises(MisconfigurationException):
        trainer.fit(model)

    class CurrentTestModel(LightTrainDataloader,
                           LightValidationStepMixin,
                           TestModelBase):
        pass

    # check val_dataloader + val_step -> val_epoch_end
    model = CurrentTestModel(hparams)
    with pytest.warns(RuntimeWarning):
        trainer.fit(model)

    class CurrentTestModel(LightTrainDataloader,
                           LightValStepFitSingleDataloaderMixin,
                           TestModelBase):
        pass

    # check val_step -> val_dataloader
    model = CurrentTestModel(hparams)
    with pytest.raises(MisconfigurationException):
        trainer.fit(model)
# Example No. 17
def test_load_model_from_checkpoint(tmpdir):
    """Verify test() on a model restored from the last saved checkpoint.

    Trains for two epochs, reloads the lexicographically last ``*.ckpt`` file
    from the checkpoint directory, checks that hparams and weights match the
    trained model, then runs ``test()`` on the restored model with a fresh
    trainer.
    """
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        show_progress_bar=False,
        max_epochs=2,
        train_percent_check=0.4,
        val_percent_check=0.2,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        logger=False,
        default_save_path=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # training must have completed before evaluating or reloading anything
    assert result == 1, 'training failed to complete'
    trainer.test()

    # load last checkpoint; plain lexicographic order is adequate for this short
    # run but would need a numeric sort if file names embed multi-digit epochs
    checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))
    assert checkpoints, 'no checkpoint file was saved'
    last_checkpoint = checkpoints[-1]
    pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)

    # test that hparams loaded correctly
    for k, v in vars(hparams).items():
        assert getattr(pretrained_model.hparams, k) == v

    # assert weights are the same; compare counts first because zip() would
    # silently ignore extra or missing parameters
    old_params = list(model.parameters())
    new_params = list(pretrained_model.parameters())
    assert len(old_params) == len(new_params), \
        'loaded model has a different number of parameters'
    for old_p, new_p in zip(old_params, new_params):
        assert torch.all(torch.eq(old_p, new_p)), 'loaded weights are not the same as the saved weights'

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer)
# Example No. 18
def test_multi_optimizer_with_scheduling(tmpdir):
    """ Verify that learning rate scheduling is working.

    The model configures two optimizers, each with a scheduler. After one
    epoch every param group of both optimizers must share one lr equal to
    ``init_lr * 0.1``.
    """
    import math

    tutils.reset_seed()

    class CurrentTestModel(
            LightTestMultipleOptimizersWithSchedulingMixin,
            LightTrainDataloader,
            TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model)
    assert results == 1, 'training failed to complete'

    init_lr = hparams.learning_rate
    adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups]
    adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups]

    assert len(trainer.lr_schedulers) == 2, \
        'all lr scheduler not initialized properly, it has %i elements instead of 2' % len(trainer.lr_schedulers)

    assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \
        'Lr not equally adjusted for all param groups for optimizer 1'
    adjusted_lr1 = adjusted_lr1[0]

    assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \
        'Lr not equally adjusted for all param groups for optimizer 2'
    adjusted_lr2 = adjusted_lr2[0]

    # compare with a tolerance: the lr is the product of float multiplications
    expected_lr = init_lr * 0.1
    assert math.isclose(adjusted_lr1, expected_lr) and math.isclose(adjusted_lr2, expected_lr), \
        'Lr not adjusted correctly, expected %f but got %f and %f' % (expected_lr, adjusted_lr1, adjusted_lr2)
# Example No. 19
def test_trainer_arg_str(tmpdir):
    """Passing a field name to ``auto_lr_find`` tunes that hparams field instead."""

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    # add a non-standard field for the finder to update
    vars(hparams)['my_fancy_lr'] = 1.0
    model = CurrentTestModel(hparams)
    lr_before_fit = hparams.my_fancy_lr

    trainer = Trainer(default_save_path=tmpdir,
                      max_epochs=1,
                      auto_lr_find='my_fancy_lr')
    trainer.fit(model)

    lr_after_fit = model.hparams.my_fancy_lr
    assert lr_before_fit != lr_after_fit, \
        'Learning rate was not altered after running learning rate finder'
def test_call_to_trainer_method(tmpdir, scale_method):
    """Calling ``Trainer.scale_batch_size`` explicitly yields a different batch size."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)
    initial_batch_size = hparams.batch_size

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    # run the batch-size search by hand (max_trials=5) and apply its result
    new_batch_size = trainer.scale_batch_size(model, mode=scale_method, max_trials=5)
    model.hparams.batch_size = new_batch_size
    trainer.fit(model)

    assert initial_batch_size != new_batch_size, \
        'Batch size was not altered after running auto scaling of batch size'
# Example No. 21
def test_on_before_zero_grad_called(max_steps):
    """``on_before_zero_grad`` is invoked ``max_steps`` times during ``fit`` and
    not at all during ``test``."""

    class CurrentTestModel(EvalModelTemplate):
        on_before_zero_grad_called = 0

        def on_before_zero_grad(self, optimizer):
            self.on_before_zero_grad_called += 1

    model = CurrentTestModel(tutils.get_default_hparams())
    trainer = Trainer(max_steps=max_steps, num_sanity_val_steps=5)

    assert model.on_before_zero_grad_called == 0
    trainer.fit(model)
    assert model.on_before_zero_grad_called == max_steps

    # reset the counter so only calls made while testing are counted
    model.on_before_zero_grad_called = 0
    trainer.test(model)
    assert model.on_before_zero_grad_called == 0
def test_multi_optimizer_with_scheduling_stepping(tmpdir):
    """Verify two optimizers whose lr schedulers step at different rates.

    After one epoch, optimizer 1's lr is expected to have been multiplied by
    0.1 three times and optimizer 2's lr once; every param group of an
    optimizer must share the same lr.
    """
    import math

    tutils.reset_seed()

    class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin,
                           LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model)
    assert results == 1, 'training failed to complete'

    init_lr = hparams.learning_rate
    adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups]
    adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups]

    assert len(trainer.lr_schedulers) == 2, \
        'all lr scheduler not initialized properly'

    assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \
        'lr not equally adjusted for all param groups for optimizer 1'
    adjusted_lr1 = adjusted_lr1[0]

    assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \
        'lr not equally adjusted for all param groups for optimizer 2'
    adjusted_lr2 = adjusted_lr2[0]

    # Optimizer 1: scheduler stepped three times during the epoch (stepping
    # every 3 batches over ~11 batches -> 3 calls; TODO confirm against the mixin).
    # Compared with a tolerance because 0.1 ** 3 is not exact in floating point.
    assert math.isclose(adjusted_lr1, init_lr * (0.1)**3), \
        'lr for optimizer 1 not adjusted correctly'
    # Optimizer 2: scheduler stepped once, at the end of the epoch.
    assert math.isclose(adjusted_lr2, init_lr * 0.1), \
        'lr for optimizer 2 not adjusted correctly'
# Example No. 23
def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
    """Verify the basic fit/test logging behaviour of each logger class.

    The logger is subclassed so every ``log_metrics`` call is recorded, and
    the recorded (step, metric names) sequence is compared with the expected
    one. ``default_root_dir=tmpdir`` is passed, as in the other tests, so
    trainer artefacts stay under the test's temporary directory.
    """
    tutils.reset_seed()

    # prevent comet logger from trying to print at exit, since
    # pytest's stdout/stderr redirection breaks it
    import atexit
    monkeypatch.setattr(atexit, 'register', lambda _: None)

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    class StoreHistoryLogger(logger_class):
        """Wraps ``logger_class`` and keeps every (step, metrics) pair it logs."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.history = []

        def log_metrics(self, metrics, step):
            super().log_metrics(metrics, step)
            self.history.append((step, metrics))

    # only pass save_dir to loggers whose constructor accepts it
    if 'save_dir' in inspect.getfullargspec(logger_class).args:
        logger = StoreHistoryLogger(save_dir=str(tmpdir))
    else:
        logger = StoreHistoryLogger()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        train_percent_check=0.2,
        val_percent_check=0.5,
        fast_dev_run=True,
    )
    trainer.fit(model)

    trainer.test()

    log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
    assert log_metric_names == [(0, ['val_acc', 'val_loss']),
                                (0, ['train_some_val']),
                                (1, ['test_acc', 'test_loss'])]
# Example No. 24
def test_inf_val_dataloader(tmpdir):
    """Test inf val data loader (e.g. IterableDataset).

    With ``val_percent_check=0.5`` and this model's validation dataloader,
    ``fit`` must raise ``MisconfigurationException``; with default validation
    settings training must complete.
    """
    class CurrentTestModel(LightInfValDataloader, LightningTestModel):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # build the trainer outside ``pytest.raises`` so only ``fit`` may raise
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      val_percent_check=0.5)
    with pytest.raises(MisconfigurationException):
        trainer.fit(model)

    # default validation settings should train normally
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    result = trainer.fit(model)

    # verify training completed
    assert result == 1
# Example No. 25
def test_multiple_loggers(tmpdir):
    """Every logger in a list passed to ``Trainer`` receives hparams and metrics
    and is finalized."""
    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    loggers = [CustomLogger(), CustomLogger()]

    trainer = Trainer(
        max_epochs=1,
        train_percent_check=0.05,
        logger=loggers,
        default_root_dir=tmpdir,
    )
    assert trainer.fit(model) == 1, "Training failed"

    for logger in loggers:
        assert logger.hparams_logged == hparams
        assert logger.metrics_logged != {}
        assert logger.finalized_status == "success"
def test_lr_logger_multi_lrs(tmpdir):
    """``LearningRateLogger`` records one lr entry per scheduler when the model
    configures several schedulers."""
    tutils.reset_seed()

    model = EvalModelTemplate(tutils.get_default_hparams())
    # use the configure_optimizers variant that sets up multiple schedulers
    model.configure_optimizers = model.configure_optimizers__multiple_schedulers

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger],
    )
    outcome = trainer.fit(model)
    assert outcome == 1

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
    allowed_names = {'lr-Adam', 'lr-Adam-1'}
    assert set(lr_logger.lrs.keys()) <= allowed_names, \
        'Names of learning rates not set correctly'
# Example No. 27
def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works.

    The logger class is put into bypass mode before use. The logger is
    finalized before the training result is checked.
    """
    model = EvalModelTemplate(tutils.get_default_hparams())
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log",
                          task_name="pytorch lightning test")

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      train_percent_check=0.05,
                      logger=logger)
    result = trainer.fit(model)

    logger.finalize()
    assert result == 1, "Training failed"
# Example No. 28
def test_reduce_lr_on_plateau_scheduling(tmpdir):
    """The model's scheduler (presumably a ReduceLROnPlateau, per the mixin name)
    is normalized into the trainer's scheduler-dict format, monitoring
    ``val_loss`` once per epoch."""

    class CurrentTestModel(LightTestReduceLROnPlateauMixin,
                           LightTrainDataloader, LightValidationMixin,
                           LightValidationStepMixin, TestModelBase):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    outcome = trainer.fit(model)
    assert outcome

    first_scheduler = trainer.lr_schedulers[0]
    expected = dict(scheduler=first_scheduler['scheduler'], monitor='val_loss',
                    interval='epoch', frequency=1, reduce_on_plateau=True)
    assert first_scheduler == expected, \
        'lr schduler was not correctly converted to dict'
# Example No. 29
def test_model_checkpoint_with_non_string_input(tmpdir):
    """``ModelCheckpoint(filepath=None)`` is accepted, and after training the
    trainer's ``ckpt_path`` differs from ``default_root_dir``."""
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    model = CurrentTestModel(tutils.get_default_hparams())
    checkpoint = ModelCheckpoint(filepath=None, save_top_k=-1)

    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint,
        overfit_pct=0.20,
        max_epochs=5,
    )
    trainer.fit(model)

    # the paths differ only if the checkpoint dirpath was filled in by the trainer
    assert trainer.ckpt_path != trainer.default_root_dir
# Example No. 30
def test_simple_cpu(tmpdir):
    """Verify that a short training session completes on CPU."""
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.1,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'cpu model failed to complete'