Example 1
def main():
    seed_everything(4321)

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--trainer_method", default="fit")
    parser.add_argument("--tmpdir")
    parser.add_argument("--workdir")
    parser.set_defaults(gpus=2)
    parser.set_defaults(strategy="ddp")
    args = parser.parse_args()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer = Trainer.from_argparse_args(args)

    if args.trainer_method == "fit":
        trainer.fit(model, datamodule=dm)
        result = None
    elif args.trainer_method == "test":
        result = trainer.test(model, datamodule=dm)
    elif args.trainer_method == "fit_test":
        trainer.fit(model, datamodule=dm)
        result = trainer.test(model, datamodule=dm)
    else:
        raise ValueError(f"Unsupported: {args.trainer_method}")

    result_ext = {
        "status": "complete",
        "method": args.trainer_method,
        "result": result
    }
    file_path = os.path.join(args.tmpdir, "ddp.result")
    torch.save(result_ext, file_path)
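
The original excerpt ends at `torch.save` above; presumably the script is executed directly by the DDP test harness. A minimal, assumed entry-point guard (not part of the original snippet) would be:

if __name__ == "__main__":
    # Hypothetical entry point: run main() when the script is launched as a program.
    main()
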
def test_callbacks_references_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint sets references as expected."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    args = {
        "default_root_dir": tmpdir,
        "max_steps": 1,
        "logger": False,
        "limit_val_batches": 2,
        "num_sanity_val_steps": 0,
    }

    # initial training
    checkpoint = ModelCheckpoint(dirpath=tmpdir,
                                 monitor="val_loss",
                                 save_last=True)
    trainer = Trainer(**args, callbacks=[checkpoint])
    assert checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm)

    # resumed training
    new_checkpoint = ModelCheckpoint(dirpath=tmpdir,
                                     monitor="val_loss",
                                     save_last=True)
    # pass in a new checkpoint object, which should take
    # precedence over the one in the last.ckpt file
    trainer = Trainer(**args, callbacks=[new_checkpoint])
    assert checkpoint is not new_checkpoint
    assert new_checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))
Example 3
def test_callbacks_state_resume_from_checkpoint(tmpdir):
    """Test that resuming from a checkpoint restores callbacks that persist state."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    callback_capture = CaptureCallbacksBeforeTraining()

    def get_trainer_args():
        checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
        trainer_args = dict(
            default_root_dir=tmpdir,
            max_steps=1,
            logger=False,
            callbacks=[checkpoint, callback_capture],
            limit_val_batches=2,
        )
        assert checkpoint.best_model_path == ""
        assert checkpoint.best_model_score is None
        return trainer_args

    # initial training
    trainer = Trainer(**get_trainer_args())
    trainer.fit(model, datamodule=dm)
    callbacks_before_resume = deepcopy(trainer.callbacks)

    # resumed training
    trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt"))
    trainer.fit(model, datamodule=dm)

    assert len(callbacks_before_resume) == len(callback_capture.callbacks)

    for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
        if isinstance(before, ModelCheckpoint):
            assert before.best_model_path == after.best_model_path
            assert before.best_model_score == after.best_model_score
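
`CaptureCallbacksBeforeTraining` is referenced in these examples but not defined in the excerpt. A minimal sketch of what such a helper might look like, inferred from its usage (it exposes a `callbacks` attribute holding a copy of the trainer's callbacks taken before training starts; the `on_pretrain_routine_end` hook is an assumption, consistent with the deprecation warning asserted in a later example):

from copy import deepcopy

from pytorch_lightning import Callback


class CaptureCallbacksBeforeTraining(Callback):
    """Hypothetical helper that snapshots the trainer's callbacks before training starts."""

    callbacks = []

    def on_pretrain_routine_end(self, trainer, pl_module):
        # Deep-copy so later changes to trainer.callbacks don't affect the snapshot.
        self.callbacks = deepcopy(trainer.callbacks)
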
def test_datamodule_parameter(tmpdir):
    """Test that the datamodule parameter works."""
    seed_everything(1)

    dm = ClassifDataModule()
    model = ClassificationModel()

    before_lr = model.lr
    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)

    lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
    after_lr = lrfinder.suggestion()
    model.lr = after_lr

    assert before_lr != after_lr, "Learning rate was not altered after running learning rate finder"
def test_full_loop(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      enable_model_summary=False,
                      deterministic=True)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(model, dm)
    assert dm.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, dm)
    assert dm.trainer is not None
    assert result[0]["test_acc"] > 0.6
def test_full_loop(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
        deterministic=True,
    )

    # fit model
    result = trainer.fit(model, dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert dm.trainer is not None
    assert result

    # validate
    result = trainer.validate(datamodule=dm)
    assert dm.trainer is not None
    assert result[0]['val_acc'] > 0.7

    # test
    result = trainer.test(datamodule=dm)
    assert dm.trainer is not None
    assert result[0]['test_acc'] > 0.6
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_main_port()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp_spawn",
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_model_prediction(pretrained_model, dataloader, min_acc=0.1)
Example 8
def test_train_val_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      enable_model_summary=False)

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""

    model = ClassificationModel()
    dm = ClassifDataModule()
    model.validation_step = None
    model.val_dataloader = None

    stopping = EarlyStopping(monitor='train_loss', min_delta=0.1, patience=0)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[stopping],
        overfit_batches=0.20,
        max_epochs=10,
    )
    trainer.fit(model, datamodule=dm)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert trainer.current_epoch < trainer.max_epochs - 1
Example 10
def test_callbacks_state_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint restores callbacks that persist state."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    callback_capture = CaptureCallbacksBeforeTraining()

    def get_trainer_args():
        checkpoint = ModelCheckpoint(dirpath=tmpdir,
                                     monitor="val_loss",
                                     save_last=True)
        trainer_args = dict(
            default_root_dir=tmpdir,
            limit_train_batches=1,
            limit_val_batches=2,
            max_epochs=1,
            logger=False,
            callbacks=[checkpoint, callback_capture],
        )
        assert checkpoint.best_model_path == ""
        assert checkpoint.best_model_score is None
        return trainer_args

    # initial training
    trainer = Trainer(**get_trainer_args())
    with pytest.deprecated_call(
        match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"
    ):
        trainer.fit(model, datamodule=dm)

    callbacks_before_resume = deepcopy(trainer.callbacks)

    # resumed training
    trainer = Trainer(**get_trainer_args())
    with pytest.deprecated_call(
        match="`Callback.on_pretrain_routine_end` hook has been deprecated in v1.6"
    ):
        trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))

    assert len(callbacks_before_resume) == len(callback_capture.callbacks)

    for before, after in zip(callbacks_before_resume,
                             callback_capture.callbacks):
        if isinstance(before, ModelCheckpoint):
            for attribute in (
                    "best_model_path",
                    "best_model_score",
                    "best_k_models",
                    "kth_best_model_path",
                    "kth_value",
                    "last_model_path",
            ):
                assert getattr(before, attribute) == getattr(after, attribute)
Example 11
def test_train_val_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
    )

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics['train_loss'] < 1.0
Example 12
def test_fit_csv_logger(tmpdir):
    dm = ClassifDataModule()
    model = ClassificationModel()
    logger = CSVLogger(save_dir=tmpdir)
    trainer = Trainer(default_root_dir=tmpdir,
                      max_steps=10,
                      logger=logger,
                      log_every_n_steps=1)
    trainer.fit(model, datamodule=dm)
    metrics_file = os.path.join(logger.log_dir,
                                ExperimentWriter.NAME_METRICS_FILE)
    assert os.path.isfile(metrics_file)
Example 13
def test_optimization(tmpdir):
    seed_everything(42)

    dm = ClassifDataModule(length=1024)
    model = ClassificationModel()

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      accelerator="hpu",
                      devices=1)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(datamodule=dm)
    assert dm.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, datamodule=dm)
    assert dm.trainer is not None
    test_result = result[0]["test_acc"]
    assert test_result > 0.6

    # test saved model
    model_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(model_path)

    model = ClassificationModel.load_from_checkpoint(model_path)

    trainer = Trainer(default_root_dir=tmpdir, accelerator="hpu", devices=1)

    result = trainer.test(model, datamodule=dm)
    saved_result = result[0]["test_acc"]
    assert saved_result == test_result
Example 14
def test_suggestion_parameters_work(tmpdir):
    """Test that default skipping does not alter results in basic case."""

    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=3)

    lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
    lr1 = lrfinder.suggestion(skip_begin=10)  # default
    lr2 = lrfinder.suggestion(skip_begin=150)  # way too high, should have an impact

    assert lr1 != lr2, "Skipping parameter did not influence learning rate"
Example 15
def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify test() on pretrained model."""
    tutils.reset_seed()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer, key='test_acc', thr=0.45)
Example 16
def test_multi_gpu_none_backend(tmpdir):
    """Make sure when using multiple GPUs the user can't use `distributed_backend = None`."""
    tutils.set_random_master_port()
    trainer_options = dict(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        gpus=2,
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    tpipes.run_model_test(trainer_options, model, dm)
Example 17
def test_multi_gpu_none_backend(tmpdir):
    """Make sure when using multiple GPUs the user can't use `accelerator = None`."""
    tutils.set_random_main_port()
    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        gpus=2,
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    tpipes.run_model_test(trainer_options, model, dm)
Example 18
def test_multi_gpu_early_stop_ddp_spawn(tmpdir):
    tutils.set_random_main_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[EarlyStopping(monitor="train_acc")],
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        strategy="ddp_spawn",
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    tpipes.run_model_test(trainer_options, model, dm)
Example 19
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""

    tutils.set_random_master_port()

    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='dp',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)
    pretrained_model.cpu()

    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_prediction_eval_model_template(pretrained_model, dataloader)
def test_resume_early_stopping_from_checkpoint(tmpdir):
    """Prevent regressions to bugs:

    https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
    https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
    """
    seed_everything(42)
    model = ClassificationModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir,
                                          monitor="train_loss",
                                          save_top_k=1)
    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback, checkpoint_callback],
        num_sanity_val_steps=0,
        max_epochs=4,
    )
    trainer.fit(model, datamodule=dm)

    assert len(early_stop_callback.saved_states) == 4

    checkpoint_filepath = checkpoint_callback.kth_best_model_path
    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1"
    early_stop_callback_state = early_stop_callback.saved_states[
        checkpoint["epoch"] - 1]
    assert 4 == len(early_stop_callback.saved_states)
    es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
    assert checkpoint["callbacks"][es_name] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state,
                                                   monitor="train_loss")
    new_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        resume_from_checkpoint=checkpoint_filepath,
        callbacks=[early_stop_callback],
    )

    with pytest.raises(MisconfigurationException,
                       match=r"You restored a checkpoint with current_epoch"):
        new_trainer.fit(model)
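
`EarlyStoppingTestRestore` is used above but not defined in the excerpt. A minimal sketch, inferred from its usage (it records its state each epoch in `saved_states` and, when constructed with an expected state, asserts on resume that the checkpoint restored it); the exact hooks and the callback `state_dict()` API are assumptions:

from pytorch_lightning.callbacks import EarlyStopping


class EarlyStoppingTestRestore(EarlyStopping):
    """Hypothetical EarlyStopping subclass that snapshots its state every epoch."""

    def __init__(self, expected_state=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expected_state = expected_state
        self.saved_states = []

    def on_train_start(self, trainer, pl_module):
        if self.expected_state is not None:
            # When resuming, verify the checkpoint restored this callback's state.
            assert self.state_dict() == self.expected_state

    def on_train_epoch_end(self, trainer, pl_module):
        # Cache a copy of the state after every training epoch.
        self.saved_states.append(self.state_dict().copy())
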
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    early_stop_callback = EarlyStopping(monitor='train_loss')
    expected_count = 4
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=expected_count,
    )
    trainer.fit(model, datamodule=dm)

    assert trainer.early_stopping_callback == early_stop_callback
    assert trainer.early_stopping_callbacks == [early_stop_callback]
    assert len(trainer.dev_debugger.early_stopping_history) == expected_count
Example 22
def test_multi_cpu_model_ddp(tmpdir):
    """Make sure DDP works."""
    tutils.set_random_master_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        gpus=None,
        num_processes=2,
        accelerator='ddp_cpu',
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    tpipes.run_model_test(trainer_options, model, data=dm, on_gpu=False)
Example 23
def test_multi_cpu_model_ddp(tmpdir):
    """Make sure DDP works."""
    tutils.set_random_main_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        gpus=None,
        num_processes=2,
        strategy="ddp_spawn",
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    tpipes.run_model_test(trainer_options, model, data=dm, on_gpu=False)
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
    model = ClassificationModel()
    dm = ClassifDataModule()
    early_stop_callback = EarlyStopping(monitor="train_loss")
    early_stop_callback._run_early_stopping_check = Mock()
    expected_count = 4
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback],
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=expected_count,
        enable_checkpointing=False,
    )
    trainer.fit(model, datamodule=dm)

    assert trainer.early_stopping_callback == early_stop_callback
    assert trainer.early_stopping_callbacks == [early_stop_callback]
    assert early_stop_callback._run_early_stopping_check.call_count == expected_count
Example 25
def test_try_resume_from_non_existing_checkpoint(tmpdir):
    """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    checkpoint_cb = ModelCheckpoint(dirpath=tmpdir,
                                    monitor="val_loss",
                                    save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=False,
        callbacks=[checkpoint_cb],
        limit_train_batches=2,
        limit_val_batches=2,
    )
    # Generate checkpoint `last.ckpt` with ClassificationModel
    trainer.fit(model, datamodule=dm)
    # restore() returns `True` if the resume/restore succeeded, else `False`
    assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"),
                                                trainer.on_gpu)
    assert not trainer.checkpoint_connector.restore(
        str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)
Example 26
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""

    tutils.set_random_master_port()

    class CustomClassificationModelDP(ClassificationModel):
        def _step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            return {'logits': logits, 'y': y}

        def training_step(self, batch, batch_idx):
            _, y = batch
            out = self._step(batch, batch_idx)
            loss = F.cross_entropy(out['logits'], y)
            return loss

        def validation_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def test_step(self, batch, batch_idx):
            return self._step(batch, batch_idx)

        def validation_step_end(self, outputs):
            self.log('val_acc', self.valid_acc(outputs['logits'],
                                               outputs['y']))

    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        gpus=[0, 1],
        accelerator='dp',
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)
    pretrained_model.cpu()

    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_prediction(pretrained_model, dataloader)