def test_ddp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
    """Ensure a user-provided ``DDPPlugin`` subclass is wired into the accelerator backend."""

    class MyDDP(DDPPlugin):
        pass

    class _PluginCheck(Callback):
        def on_fit_start(self, trainer, pl_module):
            # The backend must use the exact custom plugin instance type we passed in.
            assert isinstance(trainer.accelerator_backend.ddp_plugin, MyDDP)
            raise RuntimeError("finished plugin check")

    net = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        gpus=gpus,
        num_processes=num_processes,
        accelerator=ddp_backend,
        plugins=[MyDDP()],
        callbacks=[_PluginCheck()],
    )
    # The callback aborts fit once the check has run.
    with pytest.raises(RuntimeError, match="finished plugin check"):
        trainer.fit(net)
def test_accelerator_choice_ddp_slurm(setup_distributed_mock):
    """DDP under SLURM: GPU accelerator, DDP plugin, SLURM cluster env, rank 1."""

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert trainer.accelerator_connector.is_slurm_managing_tasks
            assert isinstance(trainer.accelerator, GPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDPPlugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
            # Rank information comes from the (mocked) SLURM environment.
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
            assert trainer.training_type_plugin.task_idx == 1
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=2, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_mixed_precision(tmpdir, hmp_params: dict):
    """bf16 precision on HPU: the precision plugin is installed and reaches the wrapped model."""

    class _PrecisionCheck(Callback):
        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
            # The strategy-wrapped model must report the configured precision.
            assert trainer.strategy.model.precision == "bf16"
            raise SystemExit

    net = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        accelerator="hpu",
        devices=1,
        plugins=[HPUPrecisionPlugin(precision="bf16", **hmp_params)],
        callbacks=_PrecisionCheck(),
    )
    # Plugin selection is resolved at construction time.
    assert isinstance(trainer.strategy, SingleHPUStrategy)
    assert isinstance(trainer.strategy.precision_plugin, HPUPrecisionPlugin)
    assert trainer.strategy.precision_plugin.precision == "bf16"
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_strategy_choice_ddp_kubeflow(set_device_mock, device_count_mock, setup_distributed_mock):
    """strategy="ddp" inside a Kubeflow cluster selects GPU + DDP + KubeflowEnvironment."""

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.accelerator, GPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDPPlugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)
            # Kubeflow workers always have local rank 0.
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 0
            assert trainer.training_type_plugin.local_rank == 0
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(fast_dev_run=True, strategy="ddp", gpus=1, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
    # The accelerator must have pinned the CUDA device exactly once.
    set_device_mock.assert_called_once()
def test_accelerator_choice_ddp_cpu_kubeflow(device_count_mock, setup_distributed_mock):
    """ddp_cpu inside Kubeflow selects CPU + DDPPlugin + KubeflowEnvironment at rank 0."""

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.accelerator, CPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDPPlugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)
            # Single Kubeflow worker: both ranks are 0.
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 0
            assert trainer.training_type_plugin.task_idx == 0
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(fast_dev_run=True, accelerator="ddp_cpu", num_processes=1, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_accelerator_choice_ddp_cpu_slurm(device_count_mock):
    """ddp_cpu under SLURM: SLURM-managed DDP over CPU processes."""

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert trainer.use_ddp
            assert trainer.accelerator_connector.is_slurm_managing_tasks
            assert isinstance(trainer.accelerator, CPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDPPlugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
            assert trainer.training_type_plugin.task_idx == 0
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(fast_dev_run=True, accelerator="ddp_cpu", num_processes=2, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_custom_model_summary_callback_summarize(tmpdir):
    """A ``ModelSummary`` subclass can override ``summarize`` and inspect the tabulated data."""

    class CustomModelSummary(ModelSummary):
        @staticmethod
        def summarize(
            summary_data: List[List[Union[str, List[str]]]],
            total_parameters: int,
            trainable_parameters: int,
            model_size: float,
        ) -> None:
            # Column headers plus the first (and only) row for BoringModel's layer.
            assert summary_data[1][0] == "Name"
            assert summary_data[1][1][0] == "layer"
            assert summary_data[2][0] == "Type"
            assert summary_data[2][1][0] == "Linear"
            assert summary_data[3][0] == "Params"
            # All 66 parameters of the model are trainable.
            assert total_parameters == 66
            assert trainable_parameters == 66

    net = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, callbacks=CustomModelSummary(), max_steps=1)
    trainer.fit(net)
def test_accelerator_choice_ddp2_te(tmpdir):
    """ddp2 under torchelastic selects the DDP2 accelerator with a TorchElasticEnvironment."""

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert trainer._distrib_type == DistributedType.DDP2
            assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator)
            assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
            # The (mocked) torchelastic env reports task 10; local rank must agree.
            assert trainer.accelerator_backend.task_idx == 10
            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(fast_dev_run=True, accelerator="ddp2", gpus=2, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_sharded_ddp_choice(tmpdir, accelerator):
    """Test to ensure that plugin is correctly chosen."""

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            # Map each sharded accelerator flag to its expected plugin class.
            expected = {
                "ddp_sharded": DDPShardedPlugin,
                "ddp_sharded_spawn": DDPSpawnShardedPlugin,
            }.get(accelerator)
            if expected is not None:
                assert isinstance(trainer.accelerator.training_type_plugin, expected)
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(fast_dev_run=True, accelerator=accelerator, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
    """A user-supplied ``NativeAMPPlugin`` subclass becomes the precision backend."""

    class MyNativeAMP(NativeAMPPlugin):
        pass

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.precision_connector.backend, MyNativeAMP)
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        precision=16,
        amp_backend="native",
        gpus=gpus,
        num_processes=num_processes,
        accelerator=ddp_backend,
        plugins=[MyNativeAMP()],
        callbacks=[_AssertOnFitStart()],
    )
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config):
    """Ensure if we use a config and turn off cpu_offload, that this is set to False within the config."""
    deepspeed_zero_config["zero_optimization"]["cpu_offload"] = False

    class _ConfigCheck(Callback):
        def on_before_accelerator_backend_setup(self, trainer, pl_module) -> None:
            # The plugin must carry the user's cpu_offload=False through unchanged.
            assert trainer.training_type_plugin.config["zero_optimization"]["cpu_offload"] is False
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)],
        precision=16,
        gpus=1,
        callbacks=[_ConfigCheck()],
    )
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_swa_deepcopy(tmpdir):
    """Test to ensure SWA Callback doesn't deepcopy dataloaders and datamodule potentially leading to OOM."""

    class TestSWA(StochasticWeightAveraging):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Flag flipped by the hook so we can verify it actually ran.
            self.on_before_accelerator_backend_setup_called = False

        def on_before_accelerator_backend_setup(self, trainer: "Trainer", pl_module: "LightningModule"):
            super().on_before_accelerator_backend_setup(trainer, pl_module)
            # The averaged copy gets its own bound dataloader method, not a copy
            # of the patched one on the source module.
            assert self._average_model.train_dataloader is not pl_module.train_dataloader
            assert self._average_model.train_dataloader.__self__ == self._average_model
            # The source module keeps the patched loader; the copy drops the trainer ref.
            assert isinstance(pl_module.train_dataloader, _PatchDataLoader)
            assert self._average_model.trainer is None
            self.on_before_accelerator_backend_setup_called = True

    net = BoringModel()
    swa = TestSWA()
    trainer = Trainer(default_root_dir=tmpdir, callbacks=swa, fast_dev_run=True)
    trainer.fit(net, train_dataloader=DataLoader(RandomDataset(32, 2)))
    assert swa.on_before_accelerator_backend_setup_called
def test_poptorch_models_at_different_stages(tmpdir):
    """``IPUPlugin`` compiles one poptorch model per stage, depending on the trainer fn."""
    plugin = IPUPlugin()
    trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, ipus=8)
    net = BoringModel()
    net.trainer = trainer
    plugin.model = net
    trainer.optimizers = net.configure_optimizers()[0]

    # Fitting pre-dispatch builds both a training and a validation model.
    trainer.state.fn = TrainerFn.FITTING
    trainer.training_type_plugin.pre_dispatch()
    assert list(trainer.training_type_plugin.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]

    # Every other trainer fn builds exactly one model, keyed by its own stage.
    remaining = (
        (TrainerFn.VALIDATING, RunningStage.VALIDATING),
        (TrainerFn.TESTING, RunningStage.TESTING),
        (TrainerFn.PREDICTING, RunningStage.PREDICTING),
    )
    for fn, stage in remaining:
        trainer.state.fn = fn
        trainer.state.stage = stage
        trainer.training_type_plugin.pre_dispatch()
        assert list(trainer.training_type_plugin.poptorch_models) == [stage]
def test_accelerator_choice_ddp2_te(device_count_mock):
    """ddp2 under torchelastic: GPU accelerator + DDP2Plugin + TorchElasticEnvironment, rank 10."""

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert trainer.use_ddp2
            assert isinstance(trainer.accelerator, GPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDP2Plugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
            # Rank 10 comes from the mocked torchelastic environment.
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
            assert trainer.training_type_plugin.task_idx == 10
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(fast_dev_run=True, accelerator="ddp2", gpus=2, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
    """A user-supplied ``ApexMixedPrecisionPlugin`` subclass becomes the precision plugin."""

    class MyApexPlugin(ApexMixedPrecisionPlugin):
        pass

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.precision_plugin, MyApexPlugin)
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        precision=16,
        amp_backend="apex",
        gpus=gpus,
        num_processes=num_processes,
        accelerator=ddp_backend,
        plugins=[MyApexPlugin(amp_level="O2")],
        callbacks=[_AssertOnFitStart()],
    )
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_custom_accelerator(tmpdir):
    """A user-defined ``Accelerator`` instance is used as the backend verbatim."""

    class Accel(Accelerator):
        def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True) -> None:
            pass

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.accelerator_backend, Accel)
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(fast_dev_run=True, accelerator=Accel(), num_processes=2, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_combined_dataloader_for_training_with_ddp(
    replace_sampler_ddp: bool, is_min_size_mode: bool, use_combined_loader: bool
):
    """When providing a CombinedLoader as the training data, it should correctly receive the distributed samplers."""
    mode = "min_size" if is_min_size_mode else "max_size_cycle"
    dim, n1, n2 = 3, 8, 6
    loaders = {
        "a": DataLoader(RandomDataset(dim, n1), batch_size=1),
        "b": DataLoader(RandomDataset(dim, n2), batch_size=1),
    }
    train_data = CombinedLoader(loaders, mode=mode) if use_combined_loader else loaders

    net = BoringModel()
    trainer = Trainer(
        strategy="ddp",
        accelerator="auto",
        devices="auto",
        replace_sampler_ddp=replace_sampler_ddp,
        multiple_trainloader_mode="max_size_cycle" if use_combined_loader else mode,
    )
    trainer._data_connector.attach_data(model=net, train_dataloaders=train_data, val_dataloaders=None, datamodule=None)

    # Combined length before DDP sharding ...
    length_before = min(n1, n2) if is_min_size_mode else max(n1, n2)
    # ... and after the distributed sampler splits batches across devices.
    if replace_sampler_ddp:
        expected_batches = math.ceil(length_before / trainer.num_devices)
    else:
        expected_batches = length_before

    trainer.reset_train_dataloader(model=net)
    assert trainer.train_dataloader is not None
    assert isinstance(trainer.train_dataloader, CombinedLoader)
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_batches
def test_specific_gpu_device_id(tmpdir):
    """With ``gpus=[1]`` and deepspeed, model and batches land on CUDA device index 1."""

    class _DeviceCheck(Callback):
        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
            assert net.device.index == 1

        def on_train_batch_start(
            self,
            trainer: Trainer,
            pl_module: LightningModule,
            batch: Any,
            batch_idx: int,
        ) -> None:
            assert batch.device.index == 1

        def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
            assert net.device.index == 1

        def on_test_batch_start(
            self,
            trainer: Trainer,
            pl_module: LightningModule,
            batch: Any,
            batch_idx: int,
            dataloader_idx: int,
        ) -> None:
            assert batch.device.index == 1

    net = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir, fast_dev_run=True, gpus=[1], strategy="deepspeed", callbacks=_DeviceCheck()
    )
    trainer.fit(net)
    trainer.test(net)
def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch):
    """Test "strict" support in lr_scheduler dict."""
    net = BoringModel()
    optimizer = optim.Adam(net.parameters())
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1 if complete_epoch else None,
        max_steps=-1 if complete_epoch else 1,
    )

    def _optim_conf(strict):
        # Monitor a metric that is never logged, so strict handling kicks in.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "giraffe", "strict": strict},
        }

    # strict=True: missing metric is an error once a full epoch completes.
    net.configure_optimizers = lambda: _optim_conf(True)
    if complete_epoch:
        with pytest.raises(
            MisconfigurationException,
            match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:",
        ):
            trainer.fit(net)
    else:
        trainer.fit(net)
    step_mock.assert_not_called()

    # strict=False: the same situation only warns, and the scheduler never steps.
    net.configure_optimizers = lambda: _optim_conf(False)
    if complete_epoch:
        with pytest.warns(
            RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict"
        ):
            trainer.fit(net)
    step_mock.assert_not_called()
def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
    """Progress-bar counts under 2-process DDP with a fractional ``val_check_interval``."""
    world_size = 2
    total_train_samples, train_batch_size = 16, 4
    total_val_samples, val_batch_size = 2, 1
    # NOTE(review): the train dataset holds 8 samples while total_train_samples
    # is 16 — presumably a per-rank dataset size; confirm against the fixture.
    train_data = DataLoader(RandomDataset(32, 8), batch_size=train_batch_size)
    val_data = DataLoader(RandomDataset(32, total_val_samples), batch_size=val_batch_size)

    net = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=1,
        val_check_interval=val_check_interval,
        accelerator="gpu",
        devices=world_size,
        strategy="ddp",
        enable_progress_bar=True,
        enable_model_summary=False,
    )
    trainer.fit(net, train_dataloaders=train_data, val_dataloaders=val_data)

    # Batches are sharded across the two ranks.
    total_train_batches = total_train_samples // (train_batch_size * world_size)
    val_check_batch = max(1, int(total_train_batches * val_check_interval))
    assert trainer.val_check_batch == val_check_batch
    val_checks_per_epoch = total_train_batches / val_check_batch
    total_val_batches = total_val_samples // (val_batch_size * world_size)

    pbar = trainer.progress_bar_callback
    if trainer.is_global_zero:
        assert pbar.val_progress_bar.n == total_val_batches
        assert pbar.val_progress_bar.total == total_val_batches
        total_val_batches = total_val_batches * val_checks_per_epoch
        assert pbar.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
        assert pbar.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
        assert pbar.is_enabled
def test_accelerator_choice_ddp_slurm(set_device_mock, device_count_mock, setup_distributed_mock):
    """Deprecated ``accelerator='ddp'`` spelling under SLURM still wires GPU + DDP correctly."""

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert trainer._accelerator_connector._is_slurm_managing_tasks
            assert isinstance(trainer.accelerator, GPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDPPlugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
            assert trainer.training_type_plugin.local_rank == 1
            raise SystemExit()

    net = BoringModel()
    # Constructing the Trainer with the old flag must emit the v1.5 deprecation warning.
    with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated in v1.5"):
        trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=2, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero):
    """Test that the progress bar is disabled when not in global rank zero."""
    pbar = TQDMProgressBar()
    net = BoringModel()
    trainer = Trainer(callbacks=[pbar], fast_dev_run=True)

    # Re-enable before each entry point; the trainer must disable it again
    # for non-zero ranks every time.
    for entry_point in (trainer.fit, trainer.predict, trainer.validate, trainer.test):
        pbar.enable()
        entry_point(net)
        assert pbar.is_disabled
def test_progress_bar_max_val_check_interval(tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates):
    """Progress-bar update sequences and final totals for a given ``val_check_interval``.

    Runs one epoch with refresh rate 3 and checks both the recorded tqdm ``n``
    update sequence (captured via ``MockTqdm``) and the final bar totals.
    """
    limit_batches = 7
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        val_check_interval=val_check_interval,
        limit_train_batches=limit_batches,
        limit_val_batches=limit_batches,
        callbacks=TQDMProgressBar(refresh_rate=3),
    )
    with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
        trainer.fit(model)

    pbar = trainer.progress_bar_callback
    assert pbar.main_progress_bar.n_values == main_progress_bar_updates
    assert pbar.val_progress_bar.n_values == val_progress_bar_updates

    # A float interval means "every int(limit * interval) batches, at least 1";
    # an int interval is taken verbatim.
    val_check_batch = (
        max(1, int(limit_batches * val_check_interval))
        if isinstance(val_check_interval, float)
        else val_check_interval
    )
    assert trainer.val_check_batch == val_check_batch
    # Validation runs only on full multiples of val_check_batch, so floor
    # division gives the count directly.  (The previous math.ceil around an
    # integer floor division was a no-op and has been removed.)
    val_checks_per_epoch = limit_batches // val_check_batch
    total_val_batches = limit_batches * val_checks_per_epoch
    assert pbar.val_progress_bar.n == limit_batches
    assert pbar.val_progress_bar.total == limit_batches
    assert pbar.main_progress_bar.n == limit_batches + total_val_batches
    assert pbar.main_progress_bar.total == limit_batches + total_val_batches
    assert pbar.is_enabled
def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock, strategy):
    """ddp2 strategies under SLURM select GPU + DDP2Plugin + SLURMEnvironment."""

    class _AssertOnFitStart(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert trainer._accelerator_connector._is_slurm_managing_tasks
            assert isinstance(trainer.accelerator, GPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDP2Plugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
            assert trainer.training_type_plugin.local_rank == 1
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(fast_dev_run=True, strategy=strategy, gpus=2, callbacks=[_AssertOnFitStart()])
    with pytest.raises(SystemExit):
        trainer.fit(net)
    # The accelerator must have pinned the CUDA device exactly once.
    set_device_mock.assert_called_once()
def test_progress_bar_fast_dev_run(tmpdir):
    """``fast_dev_run``: every loop runs a single batch and each bar reflects it."""
    net = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(net)

    bar = trainer.progress_bar_callback
    assert bar.total_train_batches == 1
    # total val batches are known only after val dataloaders have reloaded
    assert bar.total_val_batches == 1
    assert bar.train_batch_idx == 1
    assert bar.val_batch_idx == 1
    assert bar.test_batch_idx == 0
    # the main progress bar should display 2 batches (1 train, 1 val)
    assert bar.main_progress_bar.total == 2
    assert bar.main_progress_bar.n == 2

    trainer.validate(net)
    # the validation progress bar should display 1 batch
    assert bar.val_batch_idx == 1
    assert bar.val_progress_bar.total == 1
    assert bar.val_progress_bar.n == 1

    trainer.test(net)
    # the test progress bar should display 1 batch
    assert bar.test_batch_idx == 1
    assert bar.test_progress_bar.total == 1
    assert bar.test_progress_bar.n == 1
def test_deepspeed_config(tmpdir, deepspeed_zero_config):
    """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers and saves the model weights to load correctly."""

    class _OptimCheck(Callback):
        def on_train_start(self, trainer, pl_module) -> None:
            # Imported lazily: deepspeed is only importable inside the run.
            from deepspeed.runtime.lr_schedules import WarmupLR
            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer

            # DeepSpeed wraps the SGD optimizer and installs the config's
            # WarmupLR scheduler, stepped per training step.
            assert isinstance(trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
            assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD)
            assert isinstance(trainer.lr_scheduler_configs[0].scheduler, WarmupLR)
            assert trainer.lr_scheduler_configs[0].interval == "step"
            assert trainer.lr_scheduler_configs[0].opt_idx == 0

    net = BoringModel()
    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        strategy=DeepSpeedStrategy(config=deepspeed_zero_config),
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=1,
        log_every_n_steps=1,
        limit_train_batches=4,
        limit_val_batches=4,
        limit_test_batches=4,
        max_epochs=2,
        precision=16,
        callbacks=[_OptimCheck(), lr_monitor],
    )
    trainer.fit(net)
    trainer.test(net)

    # 2 epochs x 4 steps, logged every step -> 8 distinct learning rates recorded.
    assert list(lr_monitor.lrs) == ["lr-SGD"]
    assert len(set(lr_monitor.lrs["lr-SGD"])) == 8
def test_deepspeed_setup_train_dataloader(tmpdir):
    """Test DeepSpeed works when setup is required to call in the DataModule."""

    class TestSetupIsCalledDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self._setup = False

        def setup(self, stage: Optional[str] = None) -> None:
            self._setup = True

        def _make_loader(self):
            # Every dataloader hook requires setup() to have run first.
            assert self._setup
            return DataLoader(RandomDataset(32, 64), batch_size=2)

        def train_dataloader(self):
            return self._make_loader()

        def val_dataloader(self):
            return self._make_loader()

        def test_dataloader(self):
            return self._make_loader()

    net = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=DeepSpeedStrategy(logging_level=logging.INFO),
        accelerator="gpu",
        devices=1,
        fast_dev_run=True,
    )
    dm = TestSetupIsCalledDataModule()
    with mock.patch("deepspeed.utils.logging.logger.warning", autospec=True) as mock_object:
        trainer.fit(net, datamodule=dm)
    # DeepSpeed should have warned that it inferred the batch size itself.
    assert any("Tried to infer the batch size" in str(arg) for arg in mock_object.call_args_list)
def test_main_progress_bar_update_amount(tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_deltas: list, val_deltas: list):
    """Check the main bar's update deltas track train and val progress together.

    At the end of the epoch the bar must not overshoot when the number of
    steps is not divisible by the refresh rate.
    """
    net = BoringModel()
    bar = MockedUpdateProgressBars(refresh_rate=refresh_rate)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=train_batches,
        limit_val_batches=val_batches,
        callbacks=[bar],
        logger=False,
        checkpoint_callback=False,
    )
    trainer.fit(net)

    expected_main = [call(delta) for delta in train_deltas]
    bar.main_progress_bar.update.assert_has_calls(expected_main)
    if val_batches > 0:
        expected_val = [call(delta) for delta in val_deltas]
        bar.val_progress_bar.update.assert_has_calls(expected_val)
def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config):
    """Ensure if we use a config and turn off offload_optimizer, that this is set to False within the config."""
    deepspeed_zero_config["zero_optimization"]["offload_optimizer"] = False

    class _ConfigCheck(Callback):
        def setup(self, trainer, pl_module, stage=None) -> None:
            # The strategy must carry the user's offload_optimizer=False through unchanged.
            assert trainer.strategy.config["zero_optimization"]["offload_optimizer"] is False
            raise SystemExit()

    net = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        strategy=DeepSpeedStrategy(config=deepspeed_zero_config),
        precision=16,
        gpus=1,
        callbacks=[_ConfigCheck()],
    )
    with pytest.raises(SystemExit):
        trainer.fit(net)
def test_no_val_on_train_epoch_loop_restart(tmpdir):
    """Test that training validation loop doesn't get triggered at the beginning of a restart."""
    trainer_kwargs = dict(
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        enable_checkpointing=False,
    )
    net = BoringModel()
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(net)
    ckpt_path = str(tmpdir / "last.ckpt")
    trainer.save_checkpoint(ckpt_path)

    # Resume for one more epoch: validation must run exactly once (at the new
    # epoch's end), not again at the moment the loop restarts.
    trainer_kwargs["max_epochs"] = 2
    trainer = Trainer(**trainer_kwargs)
    val_loop = trainer.fit_loop.epoch_loop.val_loop
    with patch.object(val_loop, "advance", wraps=val_loop.advance) as advance_mocked:
        trainer.fit(net, ckpt_path=ckpt_path)
    assert advance_mocked.call_count == 1