示例#1
0
def test_v1_8_0_deprecate_trainer_data_loading_mixin():
    """Check that the ``TrainerDataLoadingMixin`` dataloader helpers are deprecated.

    A short ``fit`` run is done first. Then ``prepare_dataloader`` and ``request_dataloader``
    are each called on the trainer. Every call must emit a deprecation warning whose message
    matches the expected pattern.
    """
    trainer = Trainer(max_epochs=1)
    model = BoringModel()
    datamodule = BoringDataModule()
    trainer.fit(model, datamodule=datamodule)

    # Each entry pairs a zero-argument call with the pattern its warning must match.
    # Note that ``model.train_dataloader`` is passed as the bound method itself, not called.
    deprecated_calls = (
        (
            lambda: trainer.prepare_dataloader(dataloader=model.train_dataloader, shuffle=False),
            r"`TrainerDataLoadingMixin.prepare_dataloader` was deprecated in v1.6 and will be removed in v1.8.",
        ),
        (
            lambda: trainer.request_dataloader(stage=RunningStage.TRAINING),
            r"`TrainerDataLoadingMixin.request_dataloader` was deprecated in v1.6 and will be removed in v1.8.",
        ),
    )
    for invoke, pattern in deprecated_calls:
        with pytest.deprecated_call(match=pattern):
            invoke()
示例#2
0
def test_ddp_spawn_add_get_queue(tmpdir):
    """Check ``add_to_queue``/``get_from_queue`` when a custom ``DDPSpawnStrategy`` is used."""

    strategy = TestDDPSpawnStrategy()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        num_processes=2,
        strategy=strategy,
    )

    metric_value: float = 1.0
    metric_name: str = "val_acc"
    model = BoringCallbackDDPSpawnModel(metric_name, metric_value)
    trainer.fit(model, datamodule=BoringDataModule())

    # The metric reported under ``metric_name`` must equal the value the model was built with.
    assert trainer.callback_metrics[metric_name] == torch.tensor(metric_value)
    # The strategy must expose the extra attribute, presumably set by the queue hooks of
    # TestDDPSpawnStrategy (defined elsewhere). TODO confirm.
    assert strategy.new_test_val == "new_test_val"
def test_ddp_spawn_extra_parameters(tmpdir):
    """Check that the device is set correctly when training with ``DDPSpawnPlugin`` on two GPUs."""
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")

    plugin = trainer.training_type_plugin
    assert isinstance(plugin, DDPSpawnPlugin)
    assert plugin.on_gpu
    # The first of the two requested GPUs is expected to be the root device.
    assert plugin.root_device == torch.device("cuda:0")

    metric_value: float = 1.0
    metric_name: str = "val_acc"
    model = BoringCallbackDDPSpawnModel(metric_name, metric_value)
    datamodule = BoringDataModule()

    trainer.fit(model, datamodule=datamodule)
    assert trainer.callback_metrics[metric_name] == torch.tensor(metric_value)
    assert model.test_val == "test_val"
def test_init_optimizers_during_evaluation_and_prediction(tmpdir, fn):
    """Check that no optimizers, LR schedulers or optimizer frequencies are set up by ``fn``.

    ``fn`` names the ``Trainer`` method to run. It is called with ``ckpt_path=None``, so it is
    presumably one of the evaluation/prediction entry points. Its values come from a
    parametrization that is not visible here. TODO confirm.
    """

    class TestModel(BoringModel):
        def configure_optimizers(self):
            # Two Adam optimizers, each paired with its own StepLR scheduler.
            optimizers = [optim.Adam(self.parameters(), lr=0.1) for _ in range(2)]
            schedulers = [optim.lr_scheduler.StepLR(opt, step_size=1) for opt in optimizers]
            return optimizers, schedulers

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
    getattr(trainer, fn)(TestModel(), datamodule=BoringDataModule(), ckpt_path=None)

    # Even though the model configures optimizers, none of them may be attached here.
    for attribute in ("lr_schedulers", "optimizers", "optimizer_frequencies"):
        assert len(getattr(trainer, attribute)) == 0
示例#5
0
def test_ddp_spawn_extra_parameters(tmpdir):
    """Check device set-up for ``DDPSpawnStrategy`` on two GPUs.

    Also checks that ``add_to_queue``/``get_from_queue`` still work when they are implemented
    on the LightningModule (the deprecated approach).
    """
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, strategy="ddp_spawn")

    strategy = trainer.strategy
    assert isinstance(strategy, DDPSpawnStrategy)
    # The first of the two requested GPUs is expected to be the root device.
    assert strategy.root_device == torch.device("cuda:0")

    metric_value: float = 1.0
    metric_name: str = "val_acc"
    model = BoringCallbackDDPSpawnModel(metric_name, metric_value)
    trainer.fit(model, datamodule=BoringDataModule())

    assert trainer.callback_metrics[metric_name] == torch.tensor(metric_value)
    assert model.test_val == "test_val"
def test_ddp_spawn_add_get_queue(tmpdir):
    """Check ``add_to_queue``/``get_from_queue`` when a custom ``DDPSpawnPlugin`` is used."""

    plugin = TestDDPSpawnPlugin()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        num_processes=2,
        strategy=plugin,
    )

    metric_value: float = 1.0
    metric_name: str = "val_acc"
    model = BoringCallbackDDPSpawnModel(metric_name, metric_value)
    datamodule = BoringDataModule()

    # ``fit`` must emit a deprecation warning that mentions the ``add_to_queue`` method.
    expected_warning = "add_to_queue` method was deprecated in v1.5"
    with pytest.deprecated_call(match=expected_warning):
        trainer.fit(model, datamodule=datamodule)

    assert trainer.callback_metrics[metric_name] == torch.tensor(metric_value)
    assert plugin.new_test_val == "new_test_val"
示例#7
0
 def reset_instances(self):
     """Clear ``warning_cache`` and return new ``(datamodule, model, trainer)`` instances."""
     warning_cache.clear()
     fresh_datamodule = BoringDataModule()
     fresh_model = BoringModel()
     fresh_trainer = Trainer()
     return fresh_datamodule, fresh_model, fresh_trainer
示例#8
0
    def __init__(self):
        """Run the parent initialiser, then set each dataloader hook attribute to ``None``."""
        super().__init__()
        # Shadow every dataloader hook with an instance attribute of ``None``. Presumably this
        # makes the instance look like it defines no dataloaders. TODO confirm.
        for hook_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
            setattr(self, hook_name, None)


@pytest.mark.parametrize(
    "instance,available",
    [
        # ``None`` is expected to count as defined.
        (None, True),
        (BoringModel().train_dataloader(), True),
        (BoringModel(), True),
        (NoDataLoaderModel(), False),
        (BoringDataModule(), True),
    ],
)
def test_dataloader_source_available(instance, available):
    """Check ``_DataLoaderSource.is_defined`` against the expected availability of each ``instance``."""
    defined = _DataLoaderSource(instance=instance, name="train_dataloader").is_defined()
    assert defined is available


def test_dataloader_source_direct_access():
    """Check that a ``_DataLoaderSource`` wrapping a dataloader object returns that same object."""
    loader = BoringModel().train_dataloader()
    wrapped = _DataLoaderSource(instance=loader, name="any")

    # A plain dataloader is not a module, but it still counts as defined.
    assert not wrapped.is_module()
    assert wrapped.is_defined()
    assert wrapped.dataloader() is loader