Example #1
def test_default_args(mock_argparse, tmpdir):
    """Tests default argument parser for Trainer."""
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    parser = ArgumentParser(add_help=False)
    args = parser.parse_args()
    args.logger = logger

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
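
The un-mocked flow that this test exercises looks roughly like the sketch below (a minimal sketch assuming the PyTorch Lightning 1.x-era argparse integration; `Trainer.add_argparse_args` was removed in 2.0):

from argparse import ArgumentParser
from pytorch_lightning import Trainer

parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parser)  # registers every Trainer flag on the parser
args = parser.parse_args(["--max_epochs", "5"])
trainer = Trainer.from_argparse_args(args)  # builds a Trainer from the parsed Namespace
assert trainer.max_epochs == 5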
Example #2
def run_model_test(
    trainer_options,
    model: LightningModule,
    data: LightningDataModule = None,
    version=None,
    with_hpc: bool = True,
    min_acc: float = 0.25,
):
    reset_seed()
    save_dir = trainer_options["default_root_dir"]

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer = Trainer(**trainer_options)
    initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model, datamodule=data)
    post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # Check that the model is actually changed post-training
    change_ratio = torch.norm(initial_values - post_train_values)
    assert change_ratio > 0.03, f"the model changed too little after training (change ratio: {change_ratio})"

    # test model loading
    _ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model, BoringModel):
        for dataloader in test_loaders:
            run_model_prediction(model, dataloader, min_acc=min_acc)

    if with_hpc:
        # test HPC saving
        # save logger to make sure we get all the metrics
        if logger:
            logger.finalize("finished")
        hpc_save_path = trainer._checkpoint_connector.hpc_save_path(save_dir)
        trainer.save_checkpoint(hpc_save_path)
        # test HPC loading
        checkpoint_path = trainer._checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(save_dir)
        trainer._checkpoint_connector.restore(checkpoint_path)
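
A hypothetical call site for the helper above; `BoringModel` and the option values are stand-ins for whatever a real test passes in:

def test_boring_model(tmpdir):
    model = BoringModel()
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
    )
    run_model_test(trainer_options, model, with_hpc=False)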
Example #3
def test_model_saving_loading(tmpdir):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
        default_root_dir=tmpdir,
    )
    trainer.fit(model)

    # training complete
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    batch = next(iter(dataloaders[0]))

    # generate preds before saving model
    model.eval()
    pred_before_saving = model(batch)

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, "hparams.yaml")
    model_2 = BoringModel.load_from_checkpoint(checkpoint_path=new_weights_path, hparams_file=hparams_path)
    model_2.eval()

    # make prediction
    # assert that both predictions are the same
    new_pred = model_2(batch)
    assert torch.equal(pred_before_saving, new_pred)
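
For reference, a checkpoint written by `save_checkpoint` is an ordinary dict readable with `torch.load`, with the weights stored under the "state_dict" key. A minimal sketch reusing the names above:

ckpt = torch.load(new_weights_path)
model_3 = BoringModel()
model_3.load_state_dict(ckpt["state_dict"])  # plain PyTorch loading, no Lightning involved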
Example #4
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_main_port()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp_spawn",
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_model_prediction(pretrained_model, dataloader, min_acc=0.1)
Example #5
def test_strict_model_load_less_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv("TORCH_HOME", tmpdir)

    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)

    # training complete
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), "hparams.yaml")
    ckpt_url = f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
    ckpt_path = ckpt_url if url_ckpt else new_weights_path

    class CurrentModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.c_d3 = torch.nn.Linear(7, 7)

    CurrentModel.load_from_checkpoint(checkpoint_path=ckpt_path, hparams_file=hparams_path, strict=False)

    with pytest.raises(RuntimeError, match=r'Missing key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        CurrentModel.load_from_checkpoint(checkpoint_path=ckpt_path, hparams_file=hparams_path, strict=True)
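
The `strict` flag ultimately defers to `torch.nn.Module.load_state_dict`, which reports missing keys instead of raising when `strict=False`. A minimal plain-PyTorch sketch of the same mismatch (class names are illustrative):

import torch

class Small(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)

class Larger(Small):
    def __init__(self):
        super().__init__()
        self.c_d3 = torch.nn.Linear(7, 7)  # has no counterpart in Small's state_dict

result = Larger().load_state_dict(Small().state_dict(), strict=False)
print(result.missing_keys)  # ['c_d3.weight', 'c_d3.bias']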
Example #6
def test_running_test_no_val(tmpdir):
    """Verify `test()` works on a model with no `val_dataloader`.

    It performs train and test only.
    """
    class ModelTrainTest(BoringModel):
        def val_dataloader(self):
            pass

        def test_step(self, *args, **kwargs):
            output = super().test_step(*args, **kwargs)
            self.log("test_loss", output["y"])
            return output

    model = ModelTrainTest()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        limit_test_batches=0.2,
        callbacks=[checkpoint],
        logger=logger,
    )
    trainer.fit(model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"

    trainer.test()

    # test we have good test accuracy
    tutils.assert_ok_model_acc(trainer, key="test_loss")
Example #7
def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Make sure DDP + AMP work."""
    # simulate setting slurm flags
    tutils.set_random_main_port()

    model = AMPTestModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        accelerator="gpu",
        devices=[0],
        strategy="ddp_spawn",
        precision=16,
        callbacks=[checkpoint],
        logger=logger,
    )
    trainer.fit(model)

    # correct result and ok accuracy
    assert trainer.state.finished, "amp + ddp model failed to complete"

    # test root model address
    assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment)
    assert trainer.strategy.cluster_environment.resolve_root_node_address("abc") == "abc"
    assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23]") == "abc23"
    assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24]") == "abc23"
    generated = trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24, 45-40, 40]")
    assert generated == "abc23"
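
The semantics asserted above amount to "hostname prefix plus the first number of the bracketed SLURM node list". An illustrative re-implementation, not necessarily the library's exact code:

import re

def resolve_root_node_address(nodes: str) -> str:
    nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes)  # keep the first node of a range: "abc[23-24]" -> "abc23"
    nodes = re.sub(r"\[(.*?)\]", "\\1", nodes)        # unwrap a single bracketed number: "abc[23]" -> "abc23"
    return nodes.split(" ")[0].split(",")[0]          # first entry of any space/comma separated list

assert resolve_root_node_address("abc[23-24, 45-40, 40]") == "abc23"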
Example #8
def test_running_test_after_fitting(tmpdir):
    """Verify test() on fitted model."""
    class ModelTrainValTest(BoringModel):
        def validation_step(self, *args, **kwargs):
            output = super().validation_step(*args, **kwargs)
            self.log("val_loss", output["x"])
            return output

        def test_step(self, *args, **kwargs):
            output = super().test_step(*args, **kwargs)
            self.log("test_loss", output["y"])
            return output

    model = ModelTrainValTest()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        limit_test_batches=0.2,
        callbacks=[checkpoint],
        logger=logger,
    )
    trainer.fit(model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"

    trainer.test()

    # test we have good test accuracy
    tutils.assert_ok_model_acc(trainer, key="test_loss", thr=0.5)
Example #9
def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer, key="test_acc", thr=0.45)
Example #10
def test_cpu_slurm_save_load(tmpdir):
    """Verify model save/load/checkpoint on CPU."""
    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
    version = logger.version

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert trainer.state.finished, "cpu model failed to complete"

    # make a prediction with the trained model before saving
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # grab a single batch (from the last dataloader) for the prediction check
    for dataloader in dataloaders:
        for batch in dataloader:
            break

    model.eval()
    pred_before_saving = model(batch)

    # test HPC saving
    # simulate snapshot on slurm
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer._checkpoint_connector.hpc_save_path(trainer.default_root_dir)
    trainer.save_checkpoint(hpc_save_path)
    assert os.path.exists(hpc_save_path)

    # new logger file to get meta
    logger = tutils.get_default_logger(tmpdir, version=version)

    model = BoringModel()

    class _StartCallback(Callback):
        # set the epoch start hook so we can predict before the model does the full training
        def on_train_epoch_start(self, trainer, model):
            assert trainer.global_step == real_global_step and trainer.global_step > 0
            # predict with loaded model to make sure answers are the same
            mode = model.training
            model.eval()
            new_pred = model(batch)
            assert torch.eq(pred_before_saving, new_pred).all()
            model.train(mode)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        callbacks=[_StartCallback(), ModelCheckpoint(dirpath=tmpdir)],
    )
    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)
Example #11
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1, accelerator="gpu", devices=2, strategy="dp", default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options["logger"] = logger
    trainer_options["callbacks"] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer._checkpoint_connector.hpc_save_path(tmpdir)
    trainer.save_checkpoint(hpc_save_path)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options["logger"] = new_logger
    trainer_options["callbacks"] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options["limit_train_batches"] = 0.5
    trainer_options["limit_val_batches"] = 0.2
    trainer_options["max_epochs"] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        def on_validation_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
            dataloader = dm.val_dataloader()
            tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)

    # new model
    model = CustomModel()

    # validate new model which should load hpc weights
    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
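
For context, `freeze()` and `unfreeze()` behave approximately like the plain-PyTorch sketch below (an approximation of `LightningModule`'s behavior, not its exact source):

import torch

def freeze(module: torch.nn.Module) -> None:
    for param in module.parameters():
        param.requires_grad = False  # exclude parameters from gradient computation
    module.eval()

def unfreeze(module: torch.nn.Module) -> None:
    for param in module.parameters():
        param.requires_grad = True
    module.train()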