Beispiel #1
0
def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Fit ``AMPTestModel`` with 16-bit precision under ``ddp_spawn`` and check root-node resolution.

    Once the fit has finished, the strategy's cluster environment must be a
    ``SLURMEnvironment`` whose ``resolve_root_node_address`` maps every
    node-list expression below to its first host name.
    """
    tutils.set_random_main_port()

    amp_model = AMPTestModel()

    # the logger is created first because the checkpoint callback is built from it
    run_logger = tutils.get_default_logger(tmpdir)
    ckpt_callback = tutils.init_checkpoint_callback(run_logger)

    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "max_epochs": 1,
        "gpus": [0],
        "strategy": "ddp_spawn",
        "precision": 16,
        "callbacks": [ckpt_callback],
        "logger": run_logger,
    }
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(amp_model)

    # the run must have reached the finished state
    assert trainer.state.finished, "amp + ddp model failed to complete"

    cluster_env = trainer.strategy.cluster_environment
    assert isinstance(cluster_env, SLURMEnvironment)

    # (node-list expression, expected root node address)
    address_cases = (
        ("abc", "abc"),
        ("abc[23]", "abc23"),
        ("abc[23-24]", "abc23"),
        ("abc[23-24, 45-40, 40]", "abc23"),
    )
    for nodelist, expected_root in address_cases:
        assert cluster_env.resolve_root_node_address(nodelist) == expected_root
Beispiel #2
0
def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Fit a ``BoringModel`` with 16-bit precision under ``ddp_spawn`` and check root-node resolution.

    Once the fit has finished, the training-type plugin's cluster environment
    must be a ``SLURMEnvironment`` whose ``resolve_root_node_address`` maps
    every node-list expression below to its first host name.
    """
    tutils.set_random_master_port()

    boring_model = BoringModel()

    # the logger is created first because the checkpoint callback is built from it
    run_logger = tutils.get_default_logger(tmpdir)
    ckpt_callback = tutils.init_checkpoint_callback(run_logger)

    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "max_epochs": 1,
        "gpus": [0],
        "accelerator": "ddp_spawn",
        "precision": 16,
        "callbacks": [ckpt_callback],
        "logger": run_logger,
    }
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(boring_model)

    # the run must have reached the finished state
    assert trainer.state == TrainerState.FINISHED, "amp + ddp model failed to complete"

    cluster_env = trainer.training_type_plugin.cluster_environment
    assert isinstance(cluster_env, SLURMEnvironment)

    # (node-list expression, expected root node address)
    address_cases = (
        ("abc", "abc"),
        ("abc[23]", "abc23"),
        ("abc[23-24]", "abc23"),
        ("abc[23-24, 45-40, 40]", "abc23"),
    )
    for nodelist, expected_root in address_cases:
        assert cluster_env.resolve_root_node_address(nodelist) == expected_root
Beispiel #3
0
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on a pretrained model reloaded from the best checkpoint (``ddp_spawn``, GPUs 0 and 1).

    The model is fitted on ``ClassifDataModule``, reloaded from the best
    checkpoint, run through ``Trainer.test`` and then evaluated on the
    datamodule's test dataloader(s) on CPU.
    """
    tutils.set_random_master_port()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # the logger is created first because the checkpoint callback is built from it
    logger = tutils.get_default_logger(tmpdir)
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='ddp_spawn',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # the run must have finished before the checkpoint is reloaded
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    # The reloaded model was never attached to `dm`, so the datamodule has to be
    # passed explicitly; otherwise the test stage has no data source to run on.
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    # normalise to a list so one or several test dataloaders are handled alike
    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_prediction_eval_model_template(pretrained_model,
                                                  dataloader,
                                                  min_acc=0.1)
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on a pretrained model reloaded from the best checkpoint (``dp``, GPUs 0 and 1)."""
    tutils.set_random_main_port()

    datamodule = ClassifDataModule()
    dp_model = CustomClassificationModelDP(lr=0.1)

    # the logger is created first because the checkpoint callback is built from it
    run_logger = tutils.get_default_logger(tmpdir)
    ckpt_callback = tutils.init_checkpoint_callback(run_logger)

    # the same options are used for the fitting and the testing trainer
    shared_options = {
        "enable_progress_bar": False,
        "max_epochs": 2,
        "limit_train_batches": 5,
        "limit_val_batches": 5,
        "callbacks": [ckpt_callback],
        "logger": run_logger,
        "accelerator": "gpu",
        "devices": [0, 1],
        "strategy": "dp",
        "default_root_dir": tmpdir,
    }

    fit_trainer = Trainer(**shared_options)
    fit_trainer.fit(dp_model, datamodule=datamodule)

    # the run must have finished before the checkpoint is reloaded
    assert fit_trainer.state.finished, f"Training failed with {fit_trainer.state}"
    best_ckpt_path = fit_trainer.checkpoint_callback.best_model_path
    reloaded = CustomClassificationModelDP.load_from_checkpoint(best_ckpt_path)

    # run the test set with a fresh trainer, then move the model to CPU
    test_trainer = Trainer(**shared_options)
    test_trainer.test(reloaded, datamodule=datamodule)
    reloaded.cpu()

    # normalise to a list so one or several test dataloaders are handled alike
    test_loaders = datamodule.test_dataloader()
    test_loaders = test_loaders if isinstance(test_loaders, list) else [test_loaders]

    for loader in test_loaders:
        tpipes.run_model_prediction(reloaded, loader)
def test_running_test_after_fitting(tmpdir):
    """Check that `test()` can be called on the same trainer right after `fit()`."""

    # BoringModel variant that also logs its validation and test step outputs
    class ModelTrainValTest(BoringModel):

        def validation_step(self, *args, **kwargs):
            step_out = super().validation_step(*args, **kwargs)
            self.log("val_loss", step_out["x"])
            return step_out

        def test_step(self, *args, **kwargs):
            step_out = super().test_step(*args, **kwargs)
            self.log("test_loss", step_out["y"])
            return step_out

    fitted_model = ModelTrainValTest()

    # the logger is created first because the checkpoint callback is built from it
    run_logger = tutils.get_default_logger(tmpdir)
    ckpt_callback = tutils.init_checkpoint_callback(run_logger)

    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "progress_bar_refresh_rate": 0,
        "max_epochs": 2,
        "limit_train_batches": 0.4,
        "limit_val_batches": 0.2,
        "limit_test_batches": 0.2,
        "callbacks": [ckpt_callback],
        "logger": run_logger,
    }
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(fitted_model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # no model is passed: the trainer tests with what it already holds from the fit
    trainer.test()

    # check the logged `test_loss` against the 0.5 threshold
    tutils.assert_ok_model_acc(trainer, key="test_loss", thr=0.5)
Beispiel #6
0
def test_running_test_no_val(tmpdir):
    """Verify `test()` works on a model whose `val_dataloader` returns nothing.

    Only training and testing are expected to take place.
    """

    # BoringModel variant without validation data that logs its test step output
    class ModelTrainTest(BoringModel):
        def val_dataloader(self):
            return None

        def test_step(self, *args, **kwargs):
            step_out = super().test_step(*args, **kwargs)
            self.log("test_loss", step_out["y"])
            return step_out

    fitted_model = ModelTrainTest()

    # the logger is created first because the checkpoint callback is built from it
    run_logger = tutils.get_default_logger(tmpdir)
    ckpt_callback = tutils.init_checkpoint_callback(run_logger)

    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "enable_progress_bar": False,
        "max_epochs": 1,
        "limit_train_batches": 0.4,
        "limit_val_batches": 0.2,
        "limit_test_batches": 0.2,
        "callbacks": [ckpt_callback],
        "logger": run_logger,
    }
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(fitted_model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # no model is passed: the trainer tests with what it already holds from the fit
    trainer.test()

    # check the logged `test_loss` with the helper's default threshold
    tutils.assert_ok_model_acc(trainer, key="test_loss")
Beispiel #7
0
def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify `test()` on a pretrained model reloaded from the best checkpoint (no GPU options set)."""
    tutils.reset_seed()
    datamodule = ClassifDataModule()
    classifier = ClassificationModel()

    # the logger is created first because the checkpoint callback is built from it
    run_logger = tutils.get_default_logger(tmpdir)
    ckpt_callback = tutils.init_checkpoint_callback(run_logger)

    # the same options are used for the fitting and the testing trainer
    shared_options = {
        "progress_bar_refresh_rate": 0,
        "max_epochs": 2,
        "limit_train_batches": 2,
        "limit_val_batches": 2,
        "limit_test_batches": 2,
        "callbacks": [ckpt_callback],
        "logger": run_logger,
        "default_root_dir": tmpdir,
    }

    fit_trainer = Trainer(**shared_options)
    fit_trainer.fit(classifier, datamodule=datamodule)

    # the run must have finished before the checkpoint is reloaded
    assert fit_trainer.state == TrainerState.FINISHED, f"Training failed with {fit_trainer.state}"
    best_ckpt_path = fit_trainer.checkpoint_callback.best_model_path
    reloaded = ClassificationModel.load_from_checkpoint(best_ckpt_path)

    test_trainer = Trainer(**shared_options)
    test_trainer.test(reloaded, datamodule=datamodule)

    # check the logged `test_acc` against the 0.45 threshold
    tutils.assert_ok_model_acc(test_trainer, key="test_acc", thr=0.45)
Beispiel #8
0
def test_dp_resume(tmpdir):
    """Train with DP, write an HPC checkpoint, then fit a fresh model and check it resumes at the saved epoch."""
    first_model = BoringModel()

    # the logger is created first because the checkpoint callback is built from it
    first_logger = tutils.get_default_logger(tmpdir)
    ckpt_callback = tutils.init_checkpoint_callback(first_logger)

    trainer_options = {
        "max_epochs": 1,
        "gpus": 2,
        "accelerator": "dp",
        "default_root_dir": tmpdir,
        "logger": first_logger,
        "callbacks": [ckpt_callback],
    }

    # first run; the slurm flag is set on the trainer before fitting
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    trainer.fit(first_model)

    # The current epoch has already been completed, so the resumed run is
    # expected to start one epoch later.
    real_global_epoch = trainer.current_epoch + 1

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    trainer.checkpoint_connector.hpc_save(tmpdir, first_logger)

    # second trainer: same logger version, fresh checkpoint callback, fewer batches
    trainer_options.update(
        logger=tutils.get_default_logger(tmpdir, version=first_logger.version),
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
        limit_train_batches=0.5,
        limit_val_batches=0.2,
        max_epochs=1,
    )
    resumed_trainer = Trainer(**trainer_options)

    class CustomModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        # runs when training starts, i.e. before the resumed model has trained any further
        def on_train_start(self):
            epoch_now = self.trainer.current_epoch
            assert epoch_now == real_global_epoch and epoch_now > 0

            # if model and state were loaded correctly, predictions will be good
            # even though the freshly loaded model has not been trained yet
            wrapped = resumed_trainer.model
            wrapped.eval()
            wrapped.module.module.running_stage = RunningStage.EVALUATING

            train_loader = self.train_dataloader()
            tpipes.run_prediction(self.trainer.lightning_module, train_loader)
            self.on_train_start_called = True

    resumed_model = CustomModel()

    # fitting the new model is expected to load the hpc weights
    resumed_trainer.fit(resumed_model)
    assert resumed_model.on_train_start_called

    # freeze/unfreeze must still work after the run
    resumed_model.freeze()
    resumed_model.unfreeze()
Beispiel #9
0
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on a pretrained model reloaded from the best checkpoint (``dp``, GPUs 0 and 1).

    The model is fitted on ``ClassifDataModule``, reloaded from the best
    checkpoint as the same class that was trained, run through
    ``Trainer.test`` with the datamodule, and then evaluated on the
    datamodule's test dataloader(s) on CPU.
    """

    tutils.set_random_master_port()

    class CustomClassificationModelDP(ClassificationModel):
        # The steps only return logits and targets; the validation metric is
        # computed in `validation_step_end` from the step outputs.
        def _step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            return {'logits': logits, 'y': y}

        def training_step(self, batch, batch_idx):
            _, y = batch
            out = self._step(batch, batch_idx)
            loss = F.cross_entropy(out['logits'], y)
            return loss

        def validation_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def test_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def validation_step_end(self, outputs):
            self.log('val_acc', self.valid_acc(outputs['logits'],
                                               outputs['y']))

    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # the logger is created first because the checkpoint callback is built from it
    logger = tutils.get_default_logger(tmpdir)
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='dp',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # the run must have finished before the checkpoint is reloaded
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    # Reload as the class that was actually trained, so `test()` exercises the
    # overridden steps rather than the base-class ones.
    pretrained_model = CustomClassificationModelDP.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    # The reloaded model was never attached to `dm`, so the datamodule has to be
    # passed explicitly; otherwise the test stage has no data source to run on.
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    # take the loaders from the datamodule, which is where the data is defined
    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_prediction(pretrained_model, dataloader)
Beispiel #10
0
def test_dp_resume(tmpdir):
    """Train with DP, write an HPC checkpoint, then fit a fresh model and check it resumes at the saved epoch."""
    first_model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    # the logger is created first because the checkpoint callback is built from it
    first_logger = tutils.get_default_logger(tmpdir)
    ckpt_callback = tutils.init_checkpoint_callback(first_logger)

    trainer_options = {
        "max_epochs": 1,
        "gpus": 2,
        "strategy": "dp",
        "default_root_dir": tmpdir,
        "logger": first_logger,
        "callbacks": [ckpt_callback],
    }

    # first run; the slurm flag is set on the trainer before fitting
    trainer = Trainer(**trainer_options)
    trainer._is_slurm_managing_tasks = True
    trainer.fit(first_model, datamodule=dm)

    # The current epoch has already been completed, so the resumed run is
    # expected to start one epoch later.
    real_global_epoch = trainer.current_epoch + 1

    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    trainer.checkpoint_connector.hpc_save(tmpdir, first_logger)

    # second trainer: same logger version, fresh checkpoint callback, fewer batches
    trainer_options.update(
        logger=tutils.get_default_logger(tmpdir, version=first_logger.version),
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
        limit_train_batches=0.5,
        limit_val_batches=0.2,
        max_epochs=1,
    )
    resumed_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_pretrain_routine_end_called = False

        # this hook lets us predict before the resumed model does any further training
        def on_pretrain_routine_end(self):
            epoch_now = self.trainer.current_epoch
            assert epoch_now == real_global_epoch and epoch_now > 0

            # if model and state were loaded correctly, predictions will be good
            # even though the freshly loaded model has not been trained yet
            resumed_trainer.state.stage = RunningStage.VALIDATING

            train_loader = dm.train_dataloader()
            tpipes.run_prediction_eval_model_template(
                self.trainer.lightning_module, dataloader=train_loader)
            self.on_pretrain_routine_end_called = True

    resumed_model = CustomModel()

    # fitting the new model is expected to load the hpc weights
    resumed_trainer.fit(resumed_model, datamodule=dm)
    assert resumed_model.on_pretrain_routine_end_called

    # freeze/unfreeze must still work after the run
    resumed_model.freeze()
    resumed_model.unfreeze()
def test_dp_resume(tmpdir):
    """Train with DP, save a checkpoint at the HPC path, then validate a fresh model from it.

    The `on_validation_start` hook of the fresh model checks that the trainer
    reports the epoch reached by the first run, and runs a prediction check on
    the validation data. The test also asserts that the hook actually ran.
    """
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1,
                           accelerator="gpu",
                           devices=2,
                           strategy="dp",
                           default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # the checkpoint callback is built from the logger
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options["logger"] = logger
    trainer_options["callbacks"] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # the first run must have finished
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer._checkpoint_connector.hpc_save_path(tmpdir)
    trainer.save_checkpoint(hpc_save_path)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options["logger"] = new_logger
    trainer_options["callbacks"] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options["limit_train_batches"] = 0.5
    trainer_options["limit_val_batches"] = 0.2
    trainer_options["max_epochs"] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            # set by `on_validation_start` so the test can prove the hook ran
            self.on_validation_start_called = False

        def on_validation_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
            dataloader = dm.val_dataloader()
            tpipes.run_model_prediction(self.trainer.lightning_module,
                                        dataloader=dataloader)
            self.on_validation_start_called = True

    # new model
    model = CustomModel()

    # validate new model which should load hpc weights
    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)
    # without this the in-hook checks above could be skipped silently
    assert model.on_validation_start_called

    # test freeze on gpu
    model.freeze()
    model.unfreeze()