def test_optimization(tmpdir):
    seed_everything(42)

    dm = ClassifDataModule(length=1024)
    model = ClassificationModel()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="hpu", devices=1)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(datamodule=dm)
    assert dm.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, datamodule=dm)
    assert dm.trainer is not None
    test_result = result[0]["test_acc"]
    assert test_result > 0.6

    # test saved model
    model_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(model_path)

    model = ClassificationModel.load_from_checkpoint(model_path)

    trainer = Trainer(default_root_dir=tmpdir, accelerator="hpu", devices=1)

    result = trainer.test(model, datamodule=dm)
    saved_result = result[0]["test_acc"]
    assert saved_result == test_result
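
# `ClassificationModel` and `ClassifDataModule` are Lightning test helpers used
# throughout these examples. The hedged sketch below shows roughly what such a
# pair could look like (illustrative only; the real helpers train on a separable
# synthetic dataset so that the accuracy thresholds above can actually be met):

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule, LightningModule


class SketchClassificationModel(LightningModule):
    def __init__(self, lr=0.01):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.layer_0 = nn.Linear(32, 32)
        self.layer_1 = nn.Linear(32, 3)

    def forward(self, x):
        return self.layer_1(torch.relu(self.layer_0(x)))

    def _step(self, batch, stage):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log(f"{stage}_loss", loss)
        self.log(f"{stage}_acc", acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


class SketchClassifDataModule(LightningDataModule):
    def __init__(self, length=256, batch_size=32):
        super().__init__()
        self.length = length
        self.batch_size = batch_size

    def setup(self, stage=None):
        # random features/labels keep the sketch self-contained
        x = torch.randn(self.length, 32)
        y = torch.randint(0, 3, (self.length,))
        self.dataset = TensorDataset(x, y)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)
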
def main():
    seed_everything(4321)

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--trainer_method", default="fit")
    parser.add_argument("--tmpdir")
    parser.add_argument("--workdir")
    parser.set_defaults(accelerator="gpu", devices=2)
    parser.set_defaults(strategy="ddp")
    args = parser.parse_args()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer = Trainer.from_argparse_args(args)

    if args.trainer_method == "fit":
        trainer.fit(model, datamodule=dm)
        result = None
    elif args.trainer_method == "test":
        result = trainer.test(model, datamodule=dm)
    elif args.trainer_method == "fit_test":
        trainer.fit(model, datamodule=dm)
        result = trainer.test(model, datamodule=dm)
    else:
        raise ValueError(f"Unsupported: {args.trainer_method}")

    result_ext = {
        "status": "complete",
        "method": args.trainer_method,
        "result": result
    }
    file_path = os.path.join(args.tmpdir, "ddp.result")
    torch.save(result_ext, file_path)
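
# Hedged usage sketch for the script above (paths and filenames are illustrative).
# The module presumably ends with the usual entry-point guard:

if __name__ == "__main__":
    main()

# It is then launched as its own process, for example:
#
#   python ddp_script.py --trainer_method fit_test --tmpdir /tmp/ddp_run
#
# and the launching process can read the serialized result back with:
#
#   result_ext = torch.load("/tmp/ddp_run/ddp.result")
#   assert result_ext["status"] == "complete"
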
def test_callbacks_references_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint sets references as expected."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    args = {
        "default_root_dir": tmpdir,
        "max_steps": 1,
        "logger": False,
        "limit_val_batches": 2,
        "num_sanity_val_steps": 0,
    }

    # initial training
    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    trainer = Trainer(**args, callbacks=[checkpoint])
    assert checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm)

    # resumed training
    new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    # pass in a new checkpoint object, which should take
    # precedence over the one in the last.ckpt file
    trainer = Trainer(**args, callbacks=[new_checkpoint])
    assert checkpoint is not new_checkpoint
    assert new_checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_main_port()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp_spawn",
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_model_prediction(pretrained_model, dataloader, min_acc=0.1)
def test_fit_csv_logger(tmpdir):
    dm = ClassifDataModule()
    model = ClassificationModel()
    logger = CSVLogger(save_dir=tmpdir)
    trainer = Trainer(default_root_dir=tmpdir,
                      max_steps=10,
                      logger=logger,
                      log_every_n_steps=1)
    trainer.fit(model, datamodule=dm)
    metrics_file = os.path.join(logger.log_dir,
                                ExperimentWriter.NAME_METRICS_FILE)
    assert os.path.isfile(metrics_file)
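
# Follow-up sketch (not part of the original test): the metrics file written by
# CSVLogger is plain CSV, so it can be inspected with the standard library:

import csv


def read_logged_metrics(metrics_file):
    """Return the rows CSVLogger wrote, one dict per logged step."""
    with open(metrics_file, newline="") as fp:
        return list(csv.DictReader(fp))

# e.g. rows = read_logged_metrics(metrics_file); assert len(rows) > 0
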
def test_train_val_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False)

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0
def test_multi_gpu_early_stop_ddp_spawn(tmpdir):
    tutils.set_random_main_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[EarlyStopping(monitor="train_acc")],
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp_spawn",
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    tpipes.run_model_test(trainer_options, model, dm)
def test_callbacks_state_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint restores callbacks that persist state."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    callback_capture = CaptureCallbacksBeforeTraining()

    def get_trainer_args():
        checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
        trainer_args = dict(
            default_root_dir=tmpdir,
            limit_train_batches=1,
            limit_val_batches=2,
            max_epochs=1,
            logger=False,
            callbacks=[checkpoint, callback_capture],
        )
        assert checkpoint.best_model_path == ""
        assert checkpoint.best_model_score is None
        return trainer_args

    # initial training
    trainer = Trainer(**get_trainer_args())
    with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"):
        trainer.fit(model, datamodule=dm)

    callbacks_before_resume = deepcopy(trainer.callbacks)

    # resumed training
    trainer = Trainer(**get_trainer_args())
    with pytest.deprecated_call(match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"):
        trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))

    assert len(callbacks_before_resume) == len(callback_capture.callbacks)

    for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
        if isinstance(before, ModelCheckpoint):
            for attribute in (
                "best_model_path",
                "best_model_score",
                "best_k_models",
                "kth_best_model_path",
                "kth_value",
                "last_model_path",
            ):
                assert getattr(before, attribute) == getattr(after, attribute)
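
# `CaptureCallbacksBeforeTraining` is another test helper. Given the deprecation
# warning matched in the test above, it plausibly snapshots the trainer's callback
# list from the (deprecated) `on_pretrain_routine_end` hook; a hedged sketch:

from copy import deepcopy

from pytorch_lightning import Callback


class SketchCaptureCallbacksBeforeTraining(Callback):
    def __init__(self):
        self.callbacks = []

    def on_pretrain_routine_end(self, trainer, pl_module):
        self.callbacks = deepcopy(trainer.callbacks)
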
def test_resume_early_stopping_from_checkpoint(tmpdir):
    """Prevent regressions to bugs:

    https://github.com/Lightning-AI/lightning/issues/1464
    https://github.com/Lightning-AI/lightning/issues/1463
    """
    seed_everything(42)
    model = ClassificationModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir,
                                          monitor="train_loss",
                                          save_top_k=1)
    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback, checkpoint_callback],
        num_sanity_val_steps=0,
        max_epochs=4,
    )
    trainer.fit(model, datamodule=dm)

    assert len(early_stop_callback.saved_states) == 4

    checkpoint_filepath = checkpoint_callback.kth_best_model_path
    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1"
    early_stop_callback_state = early_stop_callback.saved_states[
        checkpoint["epoch"]]
    assert 4 == len(early_stop_callback.saved_states)
    es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
    assert checkpoint["callbacks"][es_name] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state,
                                                   monitor="train_loss")
    new_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        callbacks=[early_stop_callback],
    )

    with pytest.raises(MisconfigurationException,
                       match=r"You restored a checkpoint with current_epoch"):
        new_trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_filepath)
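
# `EarlyStoppingTestRestore` is a test helper as well. Judging from the assertions
# above, it records its state after every training epoch and, when constructed with
# an expected state, verifies that state once training resumes. A hedged sketch:

from copy import deepcopy

from pytorch_lightning.callbacks import EarlyStopping


class SketchEarlyStoppingTestRestore(EarlyStopping):
    def __init__(self, expected_state=None, monitor="train_loss"):
        super().__init__(monitor=monitor)
        self.expected_state = expected_state
        self.saved_states = []

    def on_train_start(self, trainer, pl_module):
        if self.expected_state is not None:
            assert self.state_dict() == self.expected_state

    def on_train_epoch_end(self, trainer, pl_module):
        super().on_train_epoch_end(trainer, pl_module)
        self.saved_states.append(deepcopy(self.state_dict()))
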
def test_multi_cpu_model_ddp(tmpdir):
    """Make sure DDP works."""
    tutils.set_random_main_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        accelerator="cpu",
        devices=2,
        strategy="ddp_spawn",
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    tpipes.run_model_test(trainer_options, model, data=dm)
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""

    model = ClassificationModel()
    dm = ClassifDataModule()
    model.validation_step = None
    model.val_dataloader = None

    stopping = EarlyStopping(monitor="train_loss",
                             min_delta=0.1,
                             patience=0,
                             check_on_train_epoch_end=True)
    trainer = Trainer(default_root_dir=tmpdir,
                      callbacks=[stopping],
                      overfit_batches=0.20,
                      max_epochs=10)
    trainer.fit(model, datamodule=dm)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.current_epoch < trainer.max_epochs - 1
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    early_stop_callback = EarlyStopping(monitor="train_loss")
    early_stop_callback._run_early_stopping_check = Mock()
    expected_count = 4
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=expected_count,
        enable_checkpointing=False,
    )
    trainer.fit(model, datamodule=dm)

    assert trainer.early_stopping_callback == early_stop_callback
    assert trainer.early_stopping_callbacks == [early_stop_callback]
    assert early_stop_callback._run_early_stopping_check.call_count == expected_count
def test_evaluate(tmpdir, trainer_kwargs):
    tutils.set_random_main_port()
    seed_everything(1)
    dm = ClassifDataModule()
    model = CustomClassificationModelDP()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_train_batches=10,
                      limit_val_batches=10,
                      **trainer_kwargs)

    trainer.fit(model, datamodule=dm)
    assert "ckpt" in trainer.checkpoint_callback.best_model_path

    old_weights = model.layer_0.weight.clone().detach().cpu()

    trainer.validate(datamodule=dm)
    trainer.test(datamodule=dm)

    # make sure weights didn't change
    new_weights = model.layer_0.weight.clone().detach().cpu()
    torch_test_assert_close(old_weights, new_weights)
def test_full_loop(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False, deterministic=True)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(model, dm)
    assert dm.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, dm)
    assert dm.trainer is not None
    assert result[0]["test_acc"] > 0.6
def test_multi_gpu_early_stop_dp(tmpdir):
    """Make sure DDP works.

    with early stopping
    """
    tutils.set_random_main_port()

    dm = ClassifDataModule()
    model = CustomClassificationModelDP()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[EarlyStopping(monitor="val_acc")],
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        accelerator="gpu",
        devices=[0, 1],
        strategy="dp",
    )

    tpipes.run_model_test(trainer_options, model, dm)
def test_lr_monitor_param_groups(tmpdir):
    """Test that learning rates are extracted and logged for single lr scheduler."""
    tutils.reset_seed()

    class CustomClassificationModel(ClassificationModel):
        def configure_optimizers(self):
            param_groups = [
                {
                    "params": list(self.parameters())[:2],
                    "lr": self.lr * 0.1
                },
                {
                    "params": list(self.parameters())[2:],
                    "lr": self.lr
                },
            ]

            optimizer = optim.Adam(param_groups)
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
            return [optimizer], [lr_scheduler]

    model = CustomClassificationModel()
    dm = ClassifDataModule()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_val_batches=0.1,
                      limit_train_batches=0.5,
                      callbacks=[lr_monitor])
    trainer.fit(model, datamodule=dm)

    assert lr_monitor.lrs, "No learning rates logged"
    assert len(lr_monitor.lrs) == 2 * len(trainer.lr_scheduler_configs)
    assert list(lr_monitor.lrs) == [
        "lr-Adam/pg1", "lr-Adam/pg2"
    ], "Names of learning rates not set correctly"
def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer, key="test_acc", thr=0.45)
def test_trainer_properties_restore_ckpt_path(tmpdir):
    """Test that required trainer properties are set correctly when resuming from checkpoint in different
    phases."""

    class CustomClassifModel(ClassificationModel):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

    model = CustomClassifModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        limit_predict_batches=2,
        logger=False,
        callbacks=[checkpoint_callback],
        num_sanity_val_steps=0,
    )
    trainer = Trainer(**trainer_args)
    trainer.fit(model, datamodule=dm)

    resume_ckpt = str(tmpdir / "last.ckpt")
    state_dict = torch.load(resume_ckpt)

    trainer_args.update({"max_epochs": 3, "enable_checkpointing": False, "callbacks": []})

    class CustomClassifModel(CustomClassifModel):
        def _is_equal(self, a, b):
            if isinstance(a, torch.Tensor):
                return torch.all(torch.eq(a, b))

            if isinstance(a, Mapping):
                return all(self._is_equal(a.get(k, None), b.get(k, None)) for k in b.keys())

            return a == b

        def _check_optimizers(self):
            return all(
                self._is_equal(optimizer.state_dict(), state)
                for optimizer, state in zip(self.trainer.optimizers, state_dict["optimizer_states"])
            )

        def _check_schedulers(self):
            return all(
                self._is_equal(config.scheduler.state_dict(), state)
                for config, state in zip(self.trainer.lr_scheduler_configs, state_dict["lr_schedulers"])
            )

        def _check_model_state_dict(self):
            return all(
                self._is_equal(actual, expected)
                for actual, expected in zip(self.state_dict(), state_dict["state_dict"])
            )

        def _test_on_val_test_predict_tune_start(self):
            assert self.trainer.current_epoch == state_dict["epoch"]
            assert self.trainer.global_step == state_dict["global_step"]
            assert self._check_model_state_dict()

            # no optimizers and schedulers are loaded otherwise
            if self.trainer.state.fn != TrainerFn.TUNING:
                return

            assert not self._check_optimizers()
            assert not self._check_schedulers()

        def on_train_start(self):
            if self.trainer.state.fn == TrainerFn.TUNING:
                self._test_on_val_test_predict_tune_start()
            else:
                assert self.trainer.current_epoch == state_dict["epoch"] + 1
                assert self.trainer.global_step == state_dict["global_step"]
                assert self._check_model_state_dict()
                assert self._check_optimizers()
                assert self._check_schedulers()

        def on_validation_start(self):
            if self.trainer.state.fn == TrainerFn.VALIDATING:
                self._test_on_val_test_predict_tune_start()

        def on_test_start(self):
            self._test_on_val_test_predict_tune_start()

    for fn in ("fit", "validate", "test", "predict"):
        model = CustomClassifModel()
        dm = ClassifDataModule()
        trainer_args["auto_scale_batch_size"] = (fn == "tune",)
        trainer = Trainer(**trainer_args)
        trainer_fn = getattr(trainer, fn)
        trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1, accelerator="gpu", devices=2, strategy="dp", default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options["logger"] = logger
    trainer_options["callbacks"] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer._checkpoint_connector.hpc_save_path(tmpdir)
    trainer.save_checkpoint(hpc_save_path)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options["logger"] = new_logger
    trainer_options["callbacks"] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options["limit_train_batches"] = 0.5
    trainer_options["limit_val_batches"] = 0.2
    trainer_options["max_epochs"] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        def on_validation_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
            dataloader = dm.val_dataloader()
            tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)

    # new model
    model = CustomModel()

    # validate new model which should load hpc weights
    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()