Example #1
def test_train_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None
    model.test_step = None
    model.test_step_end = None
    model.test_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      enable_model_summary=False)

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0
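All of these examples rely on the `ClassifDataModule` and `ClassificationModel` helpers from the PyTorch Lightning test suite (tests/helpers), which are not reproduced on this page. Below is a minimal, hypothetical stand-in that makes the snippets self-contained; the synthetic dataset, layer sizes, and logged metric names are assumptions for illustration, not the upstream implementation.

# Hypothetical minimal stand-ins for the ClassifDataModule / ClassificationModel
# helpers used throughout these examples; dataset shape, layer sizes and metric
# names are assumptions, not the real implementation from tests/helpers.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule, LightningModule


class ClassifDataModule(LightningDataModule):
    def __init__(self, length=128, batch_size=32):
        super().__init__()
        self.length = length
        self.batch_size = batch_size

    def setup(self, stage=None):
        # synthetic 32-feature, 3-class classification data
        x = torch.randn(self.length, 32)
        y = torch.randint(0, 3, (self.length,))
        self.dataset = TensorDataset(x, y)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)


class ClassificationModel(LightningModule):
    def __init__(self, lr=0.01):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.layer_0 = torch.nn.Linear(32, 32)
        self.layer_1 = torch.nn.Linear(32, 3)

    def forward(self, x):
        return self.layer_1(torch.relu(self.layer_0(x)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("val_loss", F.cross_entropy(logits, y))
        self.log("val_acc", (logits.argmax(dim=1) == y).float().mean())

    def test_step(self, batch, batch_idx):
        x, y = batch
        acc = (self(x).argmax(dim=1) == y).float().mean()
        self.log("test_acc", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)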
Example #2
def test_multi_gpu_early_stop_dp(tmpdir):
    """Make sure DDP works.

    with early stopping
    """
    tutils.set_random_master_port()

    dm = ClassifDataModule()
    model = CustomClassificationModelDP()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[EarlyStopping(monitor="val_acc")],
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        accelerator="dp",
    )

    tpipes.run_model_test(trainer_options, model, dm)
Example #3
def test_train_val_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
    )

    # fit model
    result = trainer.fit(model, datamodule=dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert result
    assert trainer.callback_metrics['train_loss'] < 1.0
Example #4
def test_full_loop(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
        deterministic=True,
    )

    # fit model
    result = trainer.fit(model, dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert result

    # test
    result = trainer.test(datamodule=dm)
    assert result[0]['test_acc'] > 0.6
def run_checkpoint_test(tmpdir, save_full_weights):
    seed_everything(1)
    model = ModelParallelClassificationModel()
    dm = ClassifDataModule()
    ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=10,
        plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
        gpus=2,
        precision=16,
        accumulate_grad_batches=2,
        callbacks=[ck]
    )
    trainer.fit(model, datamodule=dm)

    results = trainer.test(model, datamodule=dm)
    assert results[0]['test_acc'] > 0.7

    saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm)
    assert saved_results[0]['test_acc'] > 0.7
    assert saved_results == results

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=10,
        plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
        gpus=2,
        precision=16,
        accumulate_grad_batches=2,
        callbacks=[ck],
        resume_from_checkpoint=ck.best_model_path
    )
    results = trainer.test(model, datamodule=dm)
    assert results[0]['test_acc'] > 0.7

    dm.predict_dataloader = dm.test_dataloader
    results = trainer.predict(datamodule=dm)
    assert results[-1] > 0.7
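The DeepSpeed examples above and below use `ModelParallelClassificationModel`, another helper from the PyTorch Lightning test suite that is not shown here. Its key property is that it builds its layers inside `configure_sharded_model`, so DeepSpeed ZeRO Stage 3 can shard the parameters as they are created. A rough, hypothetical sketch follows; the layer sizes, metrics, and scheduler are assumptions, not the upstream code.

# Hypothetical sketch of ModelParallelClassificationModel: layers are created in
# configure_sharded_model so DeepSpeed ZeRO Stage 3 can shard them on creation.
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class ModelParallelClassificationModel(LightningModule):
    def __init__(self, lr=0.01, num_blocks=5):
        super().__init__()
        self.lr = lr
        self.num_blocks = num_blocks
        self.model = None

    def configure_sharded_model(self):
        # called by the strategy so parameter creation happens under sharding
        if self.model is None:
            blocks = [
                torch.nn.Sequential(torch.nn.Linear(32, 32, bias=False), torch.nn.ReLU())
                for _ in range(self.num_blocks)
            ]
            self.model = torch.nn.Sequential(*blocks, torch.nn.Linear(32, 3))

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_acc", (self(x).argmax(dim=1) == y).float().mean())

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_acc", (self(x).argmax(dim=1) == y).float().mean())

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]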
def test_resume_early_stopping_from_checkpoint(tmpdir):
    """
    Prevent regressions to bugs:
    https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
    https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
    """
    seed_everything(42)
    model = ClassificationModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1)
    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback, checkpoint_callback],
        num_sanity_val_steps=0,
        max_epochs=4,
    )
    trainer.fit(model, datamodule=dm)

    assert len(early_stop_callback.saved_states) == 4

    checkpoint_filepath = checkpoint_callback.kth_best_model_path
    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1"
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
    assert 4 == len(early_stop_callback.saved_states)
    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
    new_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        resume_from_checkpoint=checkpoint_filepath,
        callbacks=[early_stop_callback],
    )

    with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
        new_trainer.fit(model)
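The `EarlyStoppingTestRestore` callback used in the test above is defined alongside it in the test suite and is not shown here. A plausible sketch, assuming it caches the early-stopping state after every training epoch and checks that the state was restored when resuming (the exact upstream implementation may differ):

# Hypothetical sketch of EarlyStoppingTestRestore, assumed to cache the
# early-stopping state each epoch and verify it after restoring a checkpoint.
from pytorch_lightning.callbacks import EarlyStopping


class EarlyStoppingTestRestore(EarlyStopping):
    def __init__(self, expected_state=None, **kwargs):
        super().__init__(**kwargs)
        self.expected_state = expected_state
        self.saved_states = []

    def _state(self):
        return {
            "wait_count": self.wait_count,
            "stopped_epoch": self.stopped_epoch,
            "best_score": self.best_score,
            "patience": self.patience,
        }

    def on_train_start(self, trainer, pl_module):
        # when resuming, the state restored from the checkpoint must match
        if self.expected_state is not None:
            assert self._state() == self.expected_state

    def on_train_epoch_end(self, trainer, pl_module):
        super().on_train_epoch_end(trainer, pl_module)
        self.saved_states.append(self._state())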
Example #7
def test_try_resume_from_non_existing_checkpoint(tmpdir):
    """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    checkpoint_cb = ModelCheckpoint(dirpath=tmpdir,
                                    monitor="val_loss",
                                    save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=False,
        callbacks=[checkpoint_cb],
        limit_train_batches=2,
        limit_val_batches=2,
    )
    # Generate checkpoint `last.ckpt` with ClassificationModel
    trainer.fit(model, datamodule=dm)
    # `True` if resume/restore succeeds, else `False`
    assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"),
                                                trainer.on_gpu)
    assert not trainer.checkpoint_connector.restore(
        str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)
Example #8
def test_lr_monitor_param_groups(tmpdir):
    """ Test that learning rates are extracted and logged for single lr scheduler. """
    tutils.reset_seed()

    class CustomClassificationModel(ClassificationModel):

        def configure_optimizers(self):
            param_groups = [{
                'params': list(self.parameters())[:2],
                'lr': self.lr * 0.1
            }, {
                'params': list(self.parameters())[2:],
                'lr': self.lr
            }]

            optimizer = optim.Adam(param_groups)
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
            return [optimizer], [lr_scheduler]

    model = CustomClassificationModel()
    dm = ClassifDataModule()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_monitor],
    )
    trainer.fit(model, datamodule=dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    assert lr_monitor.lrs, 'No learning rates logged'
    assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of param groups'
    assert lr_monitor.lr_sch_names == ['lr-Adam']
    assert list(lr_monitor.lrs.keys()) == ['lr-Adam/pg1', 'lr-Adam/pg2'], \
        'Names of learning rates not set correctly'
Example #9
def test_optimization(tmpdir):
    seed_everything(42)

    dm = ClassifDataModule(length=1024)
    model = IPUClassificationModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        ipus=2,
    )

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(datamodule=dm)
    assert dm.trainer is not None
    assert result[0]['val_acc'] > 0.7

    # test
    result = trainer.test(model, datamodule=dm)
    assert dm.trainer is not None
    test_result = result[0]['test_acc']
    assert test_result > 0.6

    # test saved model
    model_path = os.path.join(tmpdir, 'model.pt')
    trainer.save_checkpoint(model_path)

    model = IPUClassificationModel.load_from_checkpoint(model_path)

    trainer = Trainer(default_root_dir=tmpdir, ipus=2)

    result = trainer.test(model, datamodule=dm)
    saved_result = result[0]['test_acc']
    assert saved_result == test_result
Example #10
def test_evaluate(tmpdir, trainer_kwargs):
    tutils.set_random_main_port()
    seed_everything(1)
    dm = ClassifDataModule()
    model = CustomClassificationModelDP()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_train_batches=10,
                      limit_val_batches=10,
                      **trainer_kwargs)

    trainer.fit(model, datamodule=dm)
    assert "ckpt" in trainer.checkpoint_callback.best_model_path

    old_weights = model.layer_0.weight.clone().detach().cpu()

    trainer.validate(datamodule=dm)
    trainer.test(datamodule=dm)

    # make sure weights didn't change
    new_weights = model.layer_0.weight.clone().detach().cpu()
    torch.testing.assert_allclose(old_weights, new_weights)
Example #11
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""

    model = ClassificationModel()
    dm = ClassifDataModule()
    model.validation_step = None
    model.val_dataloader = None

    stopping = EarlyStopping(monitor='train_loss',
                             min_delta=0.1,
                             patience=0,
                             check_on_train_epoch_end=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[stopping],
        overfit_batches=0.20,
        max_epochs=10,
    )
    trainer.fit(model, datamodule=dm)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.current_epoch < trainer.max_epochs - 1
Example #12
def test_callbacks_state_resume_from_checkpoint(tmpdir):
    """ Test that resuming from a checkpoint restores callbacks that persist state. """
    dm = ClassifDataModule()
    model = ClassificationModel()
    callback_capture = CaptureCallbacksBeforeTraining()

    def get_trainer_args():
        checkpoint = ModelCheckpoint(dirpath=tmpdir,
                                     monitor="val_loss",
                                     save_last=True)
        trainer_args = dict(default_root_dir=tmpdir,
                            max_steps=1,
                            logger=False,
                            callbacks=[
                                checkpoint,
                                callback_capture,
                            ])
        assert checkpoint.best_model_path == ""
        assert checkpoint.best_model_score is None
        return trainer_args

    # initial training
    trainer = Trainer(**get_trainer_args())
    trainer.fit(model, datamodule=dm)
    callbacks_before_resume = deepcopy(trainer.callbacks)

    # resumed training
    trainer = Trainer(**get_trainer_args(),
                      resume_from_checkpoint=str(tmpdir / "last.ckpt"))
    trainer.fit(model, datamodule=dm)

    assert len(callbacks_before_resume) == len(callback_capture.callbacks)

    for before, after in zip(callbacks_before_resume,
                             callback_capture.callbacks):
        if isinstance(before, ModelCheckpoint):
            assert before.best_model_path == after.best_model_path
            assert before.best_model_score == after.best_model_score
def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(
        tmpdir, offload_optimizer):
    """Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works."""
    seed_everything(42)

    class VerificationCallback(Callback):
        def __init__(self):
            self.on_train_batch_start_called = False

        def on_train_batch_start(self, trainer, pl_module: LightningModule,
                                 batch: Any, batch_idx: int) -> None:
            deepspeed_engine = trainer.strategy.model
            assert trainer.global_step == deepspeed_engine.global_steps
            self.on_train_batch_start_called = True

    model = ModelParallelClassificationModel()
    dm = ClassifDataModule()
    verification_callback = VerificationCallback()
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        # TODO: this test fails with max_epochs >1 as there are leftover batches per epoch.
        # there's divergence in how Lightning handles the last batch of the epoch with how DeepSpeed does it.
        # we step the optimizers on the last batch but DeepSpeed keeps the accumulation for the next epoch
        max_epochs=1,
        strategy=DeepSpeedStrategy(stage=2,
                                   offload_optimizer=offload_optimizer),
        accelerator="gpu",
        devices=2,
        limit_train_batches=5,
        limit_val_batches=2,
        precision=16,
        accumulate_grad_batches=2,
        callbacks=[verification_callback],
    )
    assert trainer.limit_train_batches % trainer.accumulate_grad_batches != 0, "leftover batches should be tested"
    trainer.fit(model, datamodule=dm)
    assert verification_callback.on_train_batch_start_called
Example #14
def test_lr_monitor_param_groups(tmpdir):
    """Test that learning rates are extracted and logged for single lr scheduler."""
    tutils.reset_seed()

    class CustomClassificationModel(ClassificationModel):
        def configure_optimizers(self):
            param_groups = [
                {
                    "params": list(self.parameters())[:2],
                    "lr": self.lr * 0.1
                },
                {
                    "params": list(self.parameters())[2:],
                    "lr": self.lr
                },
            ]

            optimizer = optim.Adam(param_groups)
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
            return [optimizer], [lr_scheduler]

    model = CustomClassificationModel()
    dm = ClassifDataModule()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_val_batches=0.1,
                      limit_train_batches=0.5,
                      callbacks=[lr_monitor])
    trainer.fit(model, datamodule=dm)

    assert lr_monitor.lrs, "No learning rates logged"
    assert len(lr_monitor.lrs) == 2 * len(trainer.lr_scheduler_configs)
    assert list(lr_monitor.lrs) == [
        "lr-Adam/pg1", "lr-Adam/pg2"
    ], "Names of learning rates not set correctly"
Example #15
def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer, key='test_acc', thr=0.45)
def _deepspeed_multigpu_stage_2_accumulated_grad_batches(
        tmpdir, offload_optimizer):
    """
    Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works.
    """
    seed_everything(42)

    class VerificationCallback(Callback):
        def __init__(self):
            self.on_train_batch_start_called = False

        def on_train_batch_start(self, trainer, pl_module: LightningModule,
                                 batch: Any, batch_idx: int,
                                 dataloader_idx: int) -> None:
            deepspeed_engine = trainer.training_type_plugin.model
            assert trainer.global_step == deepspeed_engine.global_steps
            self.on_train_batch_start_called = True

    model = ModelParallelClassificationModel()
    dm = ClassifDataModule()
    verification_callback = VerificationCallback()
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=5,
        plugins=[
            DeepSpeedPlugin(stage=2, offload_optimizer=offload_optimizer)
        ],
        gpus=2,
        limit_val_batches=2,
        precision=16,
        accumulate_grad_batches=2,
        callbacks=[verification_callback],
    )
    trainer.fit(model, datamodule=dm)
    assert verification_callback.on_train_batch_start_called
Example #17
def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir):
    """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the
    optimizer state and scheduler states cannot be restored."""
    dm = ClassifDataModule()
    model = BoringModel()
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    trainer.save_checkpoint(checkpoint_path)

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        strategy=DeepSpeedStrategy(stage=3, load_full_weights=True),
        gpus=1,
        precision=16,
    )
    with pytest.warns(
        UserWarning,
        match="A single checkpoint file has been given. This means optimizer states cannot be restored. "
        "If you'd like to restore these states, you must "
        "provide a path to the originally saved DeepSpeed checkpoint.",
    ):
        trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_path)
Example #18
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""

    tutils.set_random_master_port()

    class CustomClassificationModelDP(ClassificationModel):
        def _step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            return {'logits': logits, 'y': y}

        def training_step(self, batch, batch_idx):
            _, y = batch
            out = self._step(batch, batch_idx)
            loss = F.cross_entropy(out['logits'], y)
            return loss

        def validation_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def test_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def validation_step_end(self, outputs):
            self.log('val_acc', self.valid_acc(outputs['logits'],
                                               outputs['y']))

    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='dp',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)
    pretrained_model.cpu()

    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_prediction(pretrained_model, dataloader)
Example #19
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1,
                           accelerator="gpu",
                           devices=2,
                           strategy="dp",
                           default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options["logger"] = logger
    trainer_options["callbacks"] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer._checkpoint_connector.hpc_save_path(tmpdir)
    trainer.save_checkpoint(hpc_save_path)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options["logger"] = new_logger
    trainer_options["callbacks"] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options["limit_train_batches"] = 0.5
    trainer_options["limit_val_batches"] = 0.2
    trainer_options["max_epochs"] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        def on_validation_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
            dataloader = dm.val_dataloader()
            tpipes.run_model_prediction(self.trainer.lightning_module,
                                        dataloader=dataloader)

    # new model
    model = CustomModel()

    # validate new model which should load hpc weights
    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
Example #20
def test_trainer_properties_restore_ckpt_path(tmpdir):
    """Test that required trainer properties are set correctly when resuming from checkpoint in different
    phases."""
    class CustomClassifModel(ClassificationModel):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                           step_size=1)
            return [optimizer], [lr_scheduler]

    model = CustomClassifModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=False,
        callbacks=[checkpoint_callback],
        num_sanity_val_steps=0,
    )
    trainer = Trainer(**trainer_args)
    trainer.fit(model, datamodule=dm)

    resume_ckpt = str(tmpdir / "last.ckpt")
    state_dict = torch.load(resume_ckpt)

    trainer_args.update({
        "max_epochs": 3,
        "enable_checkpointing": False,
        "callbacks": []
    })

    class CustomClassifModel(CustomClassifModel):
        def _is_equal(self, a, b):
            if isinstance(a, torch.Tensor):
                return torch.all(torch.eq(a, b))

            if isinstance(a, Mapping):
                return all(
                    self._is_equal(a.get(k, None), b.get(k, None))
                    for k in b.keys())

            return a == b

        def _check_optimizers(self):
            return all(
                self._is_equal(self.trainer.optimizers[i].state_dict(),
                               state_dict["optimizer_states"][i])
                for i in range(len(self.trainer.optimizers)))

        def _check_schedulers(self):
            return all(
                self._is_equal(
                    self.trainer.lr_schedulers[i]["scheduler"].state_dict(),
                    state_dict["lr_schedulers"][i])
                for i in range(len(self.trainer.lr_schedulers)))

        def _check_model_state_dict(self):
            for k in self.state_dict():
                yield self._is_equal(self.state_dict()[k],
                                     state_dict["state_dict"][k])

        def _test_on_val_test_predict_tune_start(self):
            assert self.trainer.current_epoch == state_dict["epoch"]
            assert self.trainer.global_step == state_dict["global_step"]
            assert all(self._check_model_state_dict())

            # optimizers and schedulers are not loaded otherwise
            if self.trainer.state.fn != TrainerFn.TUNING:
                return

            assert not self._check_optimizers()
            assert not self._check_schedulers()

        def on_train_start(self):
            if self.trainer.state.fn == TrainerFn.TUNING:
                self._test_on_val_test_predict_tune_start()
            else:
                assert self.trainer.current_epoch == state_dict["epoch"]
                assert self.trainer.global_step == state_dict["global_step"]
                assert all(self._check_model_state_dict())
                assert self._check_optimizers()
                assert self._check_schedulers()

        def on_validation_start(self):
            if self.trainer.state.fn == TrainerFn.VALIDATING:
                self._test_on_val_test_predict_tune_start()

        def on_test_start(self):
            self._test_on_val_test_predict_tune_start()

    for fn in ("fit", "validate", "test", "predict"):
        model = CustomClassifModel()
        dm = ClassifDataModule()
        trainer_args["auto_scale_batch_size"] = (fn == "tune", )
        trainer = Trainer(**trainer_args)
        trainer_fn = getattr(trainer, fn)
        trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)
def test_deepspeed_multigpu_stage_3_resume_training(tmpdir):
    """Test to ensure with Stage 3 and single GPU that we can resume training."""
    initial_model = ModelParallelClassificationModel()
    dm = ClassifDataModule()

    ck = ModelCheckpoint(monitor="val_acc",
                         mode="max",
                         save_last=True,
                         save_top_k=-1)
    initial_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        strategy=DeepSpeedStrategy(stage=3),
        accelerator="gpu",
        devices=1,
        precision=16,
        callbacks=[ck],
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    initial_trainer.fit(initial_model, datamodule=dm)

    class TestCallback(Callback):
        def on_train_batch_start(self, trainer: Trainer,
                                 pl_module: LightningModule, batch: Any,
                                 batch_idx: int) -> None:
            original_deepspeed_strategy = initial_trainer.strategy
            current_deepspeed_strategy = trainer.strategy

            assert isinstance(original_deepspeed_strategy, DeepSpeedStrategy)
            assert isinstance(current_deepspeed_strategy, DeepSpeedStrategy)
            # assert optimizer states are correctly loaded
            original_optimizer_dict = original_deepspeed_strategy.deepspeed_engine.optimizer.state_dict()
            current_optimizer_dict = current_deepspeed_strategy.deepspeed_engine.optimizer.state_dict()
            for orig_tensor, current_tensor in zip(
                    original_optimizer_dict["fp32_flat_groups"],
                    current_optimizer_dict["fp32_flat_groups"]):
                assert torch.all(orig_tensor.eq(current_tensor))
            # assert model state is loaded correctly
            for current_param, initial_param in zip(
                    pl_module.parameters(), initial_model.parameters()):
                assert torch.equal(current_param.cpu(), initial_param.cpu())
            # assert epoch has correctly been restored
            assert trainer.current_epoch == 1

            # assert lr-scheduler states are loaded correctly
            original_lr_scheduler = initial_trainer.lr_scheduler_configs[0].scheduler
            current_lr_scheduler = trainer.lr_scheduler_configs[0].scheduler
            assert original_lr_scheduler.state_dict() == current_lr_scheduler.state_dict()

    model = ModelParallelClassificationModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        strategy=DeepSpeedStrategy(stage=3),
        accelerator="gpu",
        devices=1,
        precision=16,
        callbacks=TestCallback(),
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(model, datamodule=dm, ckpt_path=ck.best_model_path)
Example #22
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1,
                           gpus=2,
                           accelerator='dp',
                           default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['callbacks'] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    trainer.fit(model, datamodule=dm)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['callbacks'] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options['limit_train_batches'] = 0.5
    trainer_options['limit_val_batches'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        # set the epoch start hook so we can predict before the model does the full training
        def on_train_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

            # if model and state loaded correctly, predictions will be good even though we
            # haven't trained with the new loaded model
            new_trainer._running_stage = RunningStage.EVALUATING

            dataloader = self.train_dataloader()
            tpipes.run_prediction_eval_model_template(
                self.trainer.lightning_module, dataloader=dataloader)
            self.on_train_start_called = True

    # new model
    model = CustomModel()

    # fit new model which should load hpc weights
    new_trainer.fit(model, datamodule=dm)
    assert model.on_train_start_called

    # test freeze on gpu
    model.freeze()
    model.unfreeze()