Beispiel #1
0
        def on_pretrain_routine_end(self):
            """Verify the restored epoch and predict with the reloaded weights before training.

            Asserts that the trainer resumed at ``real_global_epoch`` and that this epoch is
            past 0. It then sets the stage to validating and runs a prediction pass over the
            datamodule's training loader. Finally it records that the hook ran.
            """
            epoch_now = self.trainer.current_epoch
            assert epoch_now == real_global_epoch
            assert epoch_now > 0

            # The weights and state were just loaded from the checkpoint. If that restore
            # worked, predictions should already be good before any further training.
            # NOTE(review): this mutates the enclosing test's ``new_trainer`` rather than
            # ``self.trainer``; presumably they are the same object — confirm.
            new_trainer.state.stage = RunningStage.VALIDATING

            train_loader = dm.train_dataloader()
            restored_module = self.trainer.lightning_module
            tpipes.run_model_prediction(restored_module, dataloader=train_loader)

            # Lets the surrounding test confirm this hook actually fired.
            self.on_pretrain_routine_end_called = True
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on pretrained model.

    The test fits ``ClassificationModel`` for two epochs with ``ddp_spawn`` on GPUs 0 and 1.
    It then reloads the best checkpoint and runs ``Trainer.test`` on the reloaded model.
    Finally it checks that the model, moved back to the CPU, reaches a minimal accuracy on
    every test loader. Because of ``devices=[0, 1]``, at least two GPUs are required.
    """
    tutils.set_random_main_port()
    datamodule = ClassifDataModule()
    model = ClassificationModel()

    # The logger decides where experiment metadata goes. The checkpoint callback
    # stores weights under that same logger's directory.
    logger = tutils.get_default_logger(tmpdir)
    checkpoint = tutils.init_checkpoint_callback(logger)

    # The same options are shared by the fitting trainer and the testing trainer.
    trainer_options = {
        "enable_progress_bar": False,
        "max_epochs": 2,
        "limit_train_batches": 2,
        "limit_val_batches": 2,
        "callbacks": [checkpoint],
        "logger": logger,
        "accelerator": "gpu",
        "devices": [0, 1],
        "strategy": "ddp_spawn",
        "default_root_dir": tmpdir,
    }

    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=datamodule)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    best_ckpt_path = trainer.checkpoint_callback.best_model_path
    pretrained_model = ClassificationModel.load_from_checkpoint(best_ckpt_path)

    # Test with a fresh trainer, then bring the model back to the CPU for manual checks.
    tester = Trainer(**trainer_options)
    tester.test(pretrained_model, datamodule=datamodule)
    pretrained_model.cpu()

    raw_loaders = datamodule.test_dataloader()
    # The datamodule may return a single loader or a list of them; handle both the same way.
    test_loaders = raw_loaders if isinstance(raw_loaders, list) else [raw_loaders]

    for loader in test_loaders:
        tpipes.run_model_prediction(pretrained_model, loader, min_acc=0.1)
 def on_validation_start(self):
     """Check the restored epoch and predict on the validation loader before validating.

     Asserts that the trainer resumed at ``real_global_epoch`` and that this epoch is
     past 0. It then runs a prediction pass with the trainer's current module over the
     datamodule's validation loader.
     """
     epoch_now = self.trainer.current_epoch
     assert epoch_now == real_global_epoch
     assert epoch_now > 0

     val_loader = dm.val_dataloader()
     active_module = self.trainer.lightning_module
     tpipes.run_model_prediction(active_module, dataloader=val_loader)