Example #1
def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Make sure DDP + AMP work."""
    # simulate setting slurm flags
    tutils.set_random_main_port()

    model = AMPTestModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=[0],
        strategy="ddp_spawn",
        precision=16,
        callbacks=[checkpoint],
        logger=logger,
    )
    trainer.fit(model)

    # correct result and ok accuracy
    assert trainer.state.finished, "amp + ddp model failed to complete"

    # test root model address
    assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment)
    assert trainer.strategy.cluster_environment.resolve_root_node_address("abc") == "abc"
    assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23]") == "abc23"
    assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24]") == "abc23"
    generated = trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24, 45-40, 40]")
    assert generated == "abc23"
Example #2
def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Make sure DDP + AMP work."""
    # simulate setting slurm flags
    tutils.set_random_master_port()

    model = BoringModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=[0],
        accelerator='ddp_spawn',
        precision=16,
        callbacks=[checkpoint],
        logger=logger,
    )
    _ = trainer.fit(model)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, 'amp + ddp model failed to complete'

    # test root model address
    assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
    assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc') == 'abc'
    assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23]') == 'abc23'
    assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24]') == 'abc23'
    generated = trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24, 45-40, 40]')
    assert generated == 'abc23'
Example #3
def run_model_test(
    trainer_options,
    model: LightningModule,
    data: LightningDataModule = None,
    on_gpu: bool = True,
    version=None,
    with_hpc: bool = True,
    min_acc: float = 0.25,
):
    reset_seed()
    save_dir = trainer_options["default_root_dir"]

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer = Trainer(**trainer_options)
    initial_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model, datamodule=data)
    post_train_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # Check that the model is actually changed post-training
    change_ratio = torch.norm(initial_values - post_train_values)
    assert change_ratio > 0.1, f"the model did not change enough (change_ratio={change_ratio})"

    # test model loading
    pretrained_model = load_model_from_checkpoint(
        logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model.test_dataloader(
    ) if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model, BoringModel):
        for dataloader in test_loaders:
            run_prediction_eval_model_template(model,
                                               dataloader,
                                               min_acc=min_acc)

    if with_hpc:
        if trainer._distrib_type in (DistributedType.DDP,
                                     DistributedType.DDP_SPAWN,
                                     DistributedType.DDP2):
            # on hpc this would work fine... but need to hack it for the purpose of the test
            trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers(
                pretrained_model)

        # test HPC saving
        trainer.checkpoint_connector.hpc_save(save_dir, logger)
        # test HPC loading
        checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(
            save_dir)
        trainer.checkpoint_connector.restore(checkpoint_path)
Example #4
def test_strict_model_load_less_params(monkeypatch, tmpdir, tmpdir_server,
                                       url_ckpt):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)

    # training complete
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir),
                                'hparams.yaml')
    hparams_url = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}'
    ckpt_path = hparams_url if url_ckpt else new_weights_path

    class CurrentModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.c_d3 = torch.nn.Linear(7, 7)

    CurrentModel.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )

    with pytest.raises(
            RuntimeError,
            match=r'Missing key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'
    ):
        CurrentModel.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
            strict=True,
        )
Example #5
def test_model_saving_loading(tmpdir):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
        default_root_dir=tmpdir,
    )
    trainer.fit(model)

    # training complete
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    batch = next(iter(dataloaders[0]))

    # generate preds before saving model
    model.eval()
    pred_before_saving = model(batch)

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, 'hparams.yaml')
    model_2 = BoringModel.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        hparams_file=hparams_path,
    )
    model_2.eval()

    # make prediction
    # assert that both predictions are the same
    new_pred = model_2(batch)
    assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
Example #6
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_master_port()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='ddp_spawn',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)
    pretrained_model.cpu()

    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_prediction_eval_model_template(pretrained_model,
                                                  dataloader,
                                                  min_acc=0.1)
Example #7
def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server,
                                       url_ckpt):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv("TORCH_HOME", tmpdir)

    model = BoringModel()
    # Extra layer
    model.c_d3 = torch.nn.Linear(32, 32)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)

    # training complete
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir),
                                "hparams.yaml")
    hparams_url = f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
    ckpt_path = hparams_url if url_ckpt else new_weights_path

    BoringModel.load_from_checkpoint(checkpoint_path=ckpt_path,
                                     hparams_file=hparams_path,
                                     strict=False)

    with pytest.raises(
            RuntimeError,
            match=
            r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        BoringModel.load_from_checkpoint(checkpoint_path=ckpt_path,
                                         hparams_file=hparams_path,
                                         strict=True)
Example #8
def test_default_args(mock_argparse, tmpdir):
    """Tests default argument parser for Trainer."""
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    parser = ArgumentParser(add_help=False)
    args = parser.parse_args()
    args.logger = logger

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
Example #9
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""

    tutils.set_random_main_port()

    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        accelerator="gpu",
        devices=[0, 1],
        strategy="dp",
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = CustomClassificationModelDP.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_model_prediction(pretrained_model, dataloader)
Example #10
def run_model_test(
    trainer_options,
    model: LightningModule,
    data: LightningDataModule = None,
    on_gpu: bool = True,
    version=None,
    with_hpc: bool = True,
    min_acc: float = 0.25,
):
    reset_seed()
    save_dir = trainer_options["default_root_dir"]

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer = Trainer(**trainer_options)
    initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model, datamodule=data)
    post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # Check that the model is actually changed post-training
    change_ratio = torch.norm(initial_values - post_train_values)
    assert change_ratio > 0.03, f"the model did not change enough (change_ratio={change_ratio})"

    # test model loading
    _ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model, BoringModel):
        for dataloader in test_loaders:
            run_model_prediction(model, dataloader, min_acc=min_acc)

    if with_hpc:
        # test HPC saving
        # save logger to make sure we get all the metrics
        if logger:
            logger.finalize("finished")
        hpc_save_path = trainer.checkpoint_connector.hpc_save_path(save_dir)
        trainer.save_checkpoint(hpc_save_path)
        # test HPC loading
        checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(save_dir)
        trainer.checkpoint_connector.restore(checkpoint_path)
Example #11
def test_running_test_no_val(tmpdir):
    """Verify `test()` works on a model with no `val_dataloader`.

    It performs train and test only
    """

    class ModelTrainTest(BoringModel):
        def val_dataloader(self):
            pass

        def test_step(self, *args, **kwargs):
            output = super().test_step(*args, **kwargs)
            self.log("test_loss", output["y"])
            return output

    model = ModelTrainTest()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        limit_test_batches=0.2,
        callbacks=[checkpoint],
        logger=logger,
    )
    trainer.fit(model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"

    trainer.test()

    # test we have good test accuracy
    tutils.assert_ok_model_acc(trainer, key="test_loss")
Example #12
def test_running_test_after_fitting(tmpdir):
    """Verify test() on fitted model."""

    class ModelTrainValTest(BoringModel):

        def validation_step(self, *args, **kwargs):
            output = super().validation_step(*args, **kwargs)
            self.log('val_loss', output['x'])
            return output

        def test_step(self, *args, **kwargs):
            output = super().test_step(*args, **kwargs)
            self.log('test_loss', output['y'])
            return output

    model = ModelTrainValTest()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        limit_test_batches=0.2,
        callbacks=[checkpoint],
        logger=logger,
    )
    trainer.fit(model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"

    trainer.test()

    # test we have good test accuracy
    tutils.assert_ok_model_acc(trainer, key='test_loss', thr=0.5)
Example #13
def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer, key='test_acc', thr=0.45)
Example #14
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = BoringModel()

    trainer_options = dict(max_epochs=1,
                           gpus=2,
                           accelerator='dp',
                           default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['callbacks'] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    trainer.fit(model)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['callbacks'] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options['limit_train_batches'] = 0.5
    trainer_options['limit_val_batches'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        # set the epoch start hook so we can predict before the model does the full training
        def on_train_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

            # if model and state loaded correctly, predictions will be good even though we
            # haven't trained with the new loaded model
            dp_model = new_trainer.model
            dp_model.eval()
            dp_model.module.module.running_stage = RunningStage.EVALUATING

            dataloader = self.train_dataloader()
            tpipes.run_prediction(self.trainer.lightning_module, dataloader)
            self.on_train_start_called = True

    # new model
    model = CustomModel()

    # fit new model which should load hpc weights
    new_trainer.fit(model)
    assert model.on_train_start_called

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
Example #15
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""

    tutils.set_random_master_port()

    class CustomClassificationModelDP(ClassificationModel):
        def _step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            return {'logits': logits, 'y': y}

        def training_step(self, batch, batch_idx):
            _, y = batch
            out = self._step(batch, batch_idx)
            loss = F.cross_entropy(out['logits'], y)
            return loss

        def validation_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def test_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def validation_step_end(self, outputs):
            self.log('val_acc', self.valid_acc(outputs['logits'],
                                               outputs['y']))

    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='dp',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)
    pretrained_model.cpu()

    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_prediction(pretrained_model, dataloader)
Example #16
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1,
                           accelerator="gpu",
                           devices=2,
                           strategy="dp",
                           default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options["logger"] = logger
    trainer_options["callbacks"] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer._checkpoint_connector.hpc_save_path(tmpdir)
    trainer.save_checkpoint(hpc_save_path)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options["logger"] = new_logger
    trainer_options["callbacks"] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options["limit_train_batches"] = 0.5
    trainer_options["limit_val_batches"] = 0.2
    trainer_options["max_epochs"] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        def on_validation_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
            dataloader = dm.val_dataloader()
            tpipes.run_model_prediction(self.trainer.lightning_module,
                                        dataloader=dataloader)

    # new model
    model = CustomModel()

    # validate new model which should load hpc weights
    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
Example #17
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1,
                           gpus=2,
                           strategy="dp",
                           default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options["logger"] = logger
    trainer_options["callbacks"] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer._is_slurm_managing_tasks = True
    trainer.fit(model, datamodule=dm)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options["logger"] = new_logger
    trainer_options["callbacks"] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options["limit_train_batches"] = 0.5
    trainer_options["limit_val_batches"] = 0.2
    trainer_options["max_epochs"] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_pretrain_routine_end_called = False

        # set the epoch start hook so we can predict before the model does the full training
        def on_pretrain_routine_end(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

            # if model and state loaded correctly, predictions will be good even though we
            # haven't trained with the new loaded model
            new_trainer.state.stage = RunningStage.VALIDATING

            dataloader = dm.train_dataloader()
            tpipes.run_prediction_eval_model_template(
                self.trainer.lightning_module, dataloader=dataloader)
            self.on_pretrain_routine_end_called = True

    # new model
    model = CustomModel()

    # fit new model which should load hpc weights
    new_trainer.fit(model, datamodule=dm)
    assert model.on_pretrain_routine_end_called

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
Example #18
def test_cpu_slurm_save_load(tmpdir):
    """Verify model save/load/checkpoint on CPU."""
    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
    version = logger.version

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert trainer.state.finished, "cpu model failed to complete"

    # predict with trained model before saving
    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        for batch in dataloader:
            break

    model.eval()
    pred_before_saving = model(batch)

    # test HPC saving
    # simulate snapshot on slurm
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer.checkpoint_connector.hpc_save_path(
        trainer.weights_save_path)
    trainer.save_checkpoint(hpc_save_path)
    assert os.path.exists(hpc_save_path)

    # new logger file to get meta
    logger = tutils.get_default_logger(tmpdir, version=version)

    model = BoringModel()

    class _StartCallback(Callback):
        # set the epoch start hook so we can predict before the model does the full training
        def on_train_epoch_start(self, trainer, model):
            assert trainer.global_step == real_global_step and trainer.global_step > 0
            # predict with loaded model to make sure answers are the same
            mode = model.training
            model.eval()
            new_pred = model(batch)
            assert torch.eq(pred_before_saving, new_pred).all()
            model.train(mode)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        callbacks=[_StartCallback(),
                   ModelCheckpoint(dirpath=tmpdir)],
    )
    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)