def test_v1_8_0_deprecate_trainer_data_loading_mixin():
    """Check that the ``TrainerDataLoadingMixin`` dataloader helpers emit deprecation warnings.

    The trainer is fitted once, then ``prepare_dataloader`` and ``request_dataloader`` are each
    called inside ``pytest.deprecated_call`` with a message stating deprecation in v1.6 and
    removal in v1.8.
    """
    trainer = Trainer(max_epochs=1)
    model = BoringModel()
    dm = BoringDataModule()
    trainer.fit(model, datamodule=dm)

    with pytest.deprecated_call(
        match=r"`TrainerDataLoadingMixin.prepare_dataloader` was deprecated in v1.6 and will be removed in v1.8.",
    ):
        # Hand the helper an actual DataLoader (note the call): passing the bound
        # ``train_dataloader`` method would only exercise it with a non-dataloader input.
        trainer.prepare_dataloader(dataloader=model.train_dataloader(), shuffle=False)

    with pytest.deprecated_call(
        match=r"`TrainerDataLoadingMixin.request_dataloader` was deprecated in v1.6 and will be removed in v1.8.",
    ):
        trainer.request_dataloader(stage=RunningStage.TRAINING)
def test_ddp_spawn_add_get_queue(tmpdir):
    """Tests add_to_queue/get_from_queue with DDPSpawnStrategy.

    Fits a ``BoringCallbackDDPSpawnModel`` on two processes using a custom
    ``TestDDPSpawnStrategy``. Afterwards it checks the logged metric and the
    ``new_test_val`` attribute recorded on the strategy instance.
    """
    strategy = TestDDPSpawnStrategy()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        num_processes=2,
        strategy=strategy,
    )

    metric_name: str = "val_acc"
    metric_value: float = 1.0
    # Model is built before the datamodule, matching argument evaluation order.
    trainer.fit(BoringCallbackDDPSpawnModel(metric_name, metric_value), datamodule=BoringDataModule())

    assert trainer.callback_metrics[metric_name] == torch.tensor(metric_value)
    assert strategy.new_test_val == "new_test_val"
def test_ddp_spawn_extra_parameters(tmpdir):
    """Tests if device is set correctly when training for DDPSpawnPlugin.

    Also checks that, after fitting, the logged ``val_acc`` metric and the model's
    ``test_val`` attribute hold the expected values.
    """
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")

    assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
    assert trainer.training_type_plugin.on_gpu
    # The main process should be placed on the first GPU.
    assert trainer.training_type_plugin.root_device == torch.device("cuda:0")

    val: float = 1.0
    val_name: str = "val_acc"
    model = BoringCallbackDDPSpawnModel(val_name, val)
    dm = BoringDataModule()
    trainer.fit(model, datamodule=dm)

    assert trainer.callback_metrics[val_name] == torch.tensor(val)
    assert model.test_val == "test_val"
def test_init_optimizers_during_evaluation_and_prediction(tmpdir, fn):
    """Outside of fitting, the trainer must not hold any optimizers or schedulers.

    ``fn`` names a ``Trainer`` entry point that accepts a model, ``datamodule`` and
    ``ckpt_path``; per this test's purpose it is an evaluation or prediction stage.
    """

    class TestModel(BoringModel):
        def configure_optimizers(self):
            # Two Adam optimizers, each paired with its own StepLR scheduler.
            optimizers = [optim.Adam(self.parameters(), lr=0.1) for _ in range(2)]
            schedulers = [optim.lr_scheduler.StepLR(opt, step_size=1) for opt in optimizers]
            return optimizers, schedulers

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
    run_stage = getattr(trainer, fn)
    run_stage(TestModel(), datamodule=BoringDataModule(), ckpt_path=None)

    for attribute in ("lr_schedulers", "optimizers", "optimizer_frequencies"):
        assert len(getattr(trainer, attribute)) == 0, attribute
def test_ddp_spawn_extra_parameters(tmpdir):
    """Check device placement for DDPSpawnStrategy, plus the LightningModule queue hooks.

    The add_to_queue/get_from_queue round trip here goes through the LightningModule,
    which is the deprecated way of doing it.
    """
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, strategy="ddp_spawn")

    strategy = trainer.strategy
    assert isinstance(strategy, DDPSpawnStrategy)
    assert strategy.root_device == torch.device("cuda:0")

    metric_name: str = "val_acc"
    metric_value: float = 1.0
    spawn_model = BoringCallbackDDPSpawnModel(metric_name, metric_value)
    trainer.fit(spawn_model, datamodule=BoringDataModule())

    assert trainer.callback_metrics[metric_name] == torch.tensor(metric_value)
    assert spawn_model.test_val == "test_val"
def test_ddp_spawn_add_get_queue(tmpdir):
    """Tests add_to_queue/get_from_queue with DDPSpawnPlugin.

    Fitting must emit the v1.5 ``add_to_queue`` deprecation warning. Afterwards the
    logged metric and the plugin's ``new_test_val`` attribute are checked.
    """
    plugin = TestDDPSpawnPlugin()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        num_processes=2,
        strategy=plugin,
    )

    metric_name: str = "val_acc"
    metric_value: float = 1.0
    spawn_model = BoringCallbackDDPSpawnModel(metric_name, metric_value)
    datamodule = BoringDataModule()

    with pytest.deprecated_call(match="add_to_queue` method was deprecated in v1.5"):
        trainer.fit(spawn_model, datamodule=datamodule)

    assert trainer.callback_metrics[metric_name] == torch.tensor(metric_value)
    assert plugin.new_test_val == "new_test_val"
def reset_instances(self):
    """Clear ``warning_cache`` and return fresh ``(datamodule, model, trainer)`` objects."""
    warning_cache.clear()
    datamodule = BoringDataModule()
    model = BoringModel()
    trainer = Trainer()
    return datamodule, model, trainer
def __init__(self):
    """Initialise the module and set every ``*_dataloader`` attribute to ``None``."""
    super().__init__()
    # Overwrite each dataloader hook with None so the instance exposes no dataloaders.
    for hook in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
        setattr(self, hook, None)


@pytest.mark.parametrize(
    ("instance", "available"),
    [
        (None, True),
        (BoringModel().train_dataloader(), True),
        (BoringModel(), True),
        (NoDataLoaderModel(), False),
        (BoringDataModule(), True),
    ],
)
def test_dataloader_source_available(instance, available):
    """``_DataLoaderSource.is_defined`` reports the expected availability per source instance."""
    assert _DataLoaderSource(instance=instance, name="train_dataloader").is_defined() is available


def test_dataloader_source_direct_access():
    """A source built from a dataloader is not a module and returns that same dataloader."""
    loader = BoringModel().train_dataloader()
    wrapped = _DataLoaderSource(instance=loader, name="any")
    assert not wrapped.is_module()
    assert wrapped.is_defined()
    assert wrapped.dataloader() is loader