# Imports assumed for this excerpt of a PyTorch Lightning test module; exact paths
# can differ between Lightning versions.
import inspect
from unittest import mock
from unittest.mock import ANY, Mock

import poptorch
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import IPUAccelerator
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.strategies import IPUStrategy
from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn


def test_poptorch_models_at_different_stages(tmpdir):
    """Test that the strategy creates and tracks a separate poptorch model for each running stage."""
    plugin = IPUStrategy()
    trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, accelerator="ipu", devices=8)
    model = BoringModel()
    model.trainer = trainer
    plugin.model = model
    trainer.optimizers = model.configure_optimizers()[0]

    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.setup(trainer)
    assert list(trainer.strategy.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]

    for fn, stage in (
        (TrainerFn.VALIDATING, RunningStage.VALIDATING),
        (TrainerFn.TESTING, RunningStage.TESTING),
        (TrainerFn.PREDICTING, RunningStage.PREDICTING),
    ):
        trainer.state.fn = fn
        trainer.state.stage = stage
        trainer.strategy.setup(trainer)
        assert list(trainer.strategy.poptorch_models) == [stage]
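
# Hedged illustration of the per-stage registry shape asserted above: the strategy's
# poptorch_models behaves like a mapping keyed by RunningStage. During fitting both a
# training and a validation model are compiled; every other entry point compiles exactly
# one. This dict only restates the test's expectations; it is not the strategy's real
# internal data structure.
EXPECTED_POPTORCH_MODELS_PER_FN = {
    TrainerFn.FITTING: [RunningStage.TRAINING, RunningStage.VALIDATING],
    TrainerFn.VALIDATING: [RunningStage.VALIDATING],
    TrainerFn.TESTING: [RunningStage.TESTING],
    TrainerFn.PREDICTING: [RunningStage.PREDICTING],
}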
def test_dataloader_source_request_from_module():
    """Test that requesting a dataloader from a module works."""
    module = BoringModel()
    module.trainer = Trainer()
    module.foo = Mock(return_value=module.train_dataloader())

    source = _DataLoaderSource(module, "foo")
    assert source.is_module()
    # The hook must only be called lazily, when the dataloader is actually requested.
    module.foo.assert_not_called()
    assert isinstance(source.dataloader(), DataLoader)
    module.foo.assert_called_once()
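
# Companion sketch (an assumption, mirroring the _DataLoaderSource API exercised above):
# when the source already is a DataLoader rather than a module, is_module() should be
# False and dataloader() should hand the instance back unchanged.
def test_dataloader_source_direct_access_sketch():
    dataloader = BoringModel().train_dataloader()
    source = _DataLoaderSource(dataloader, "any_name")
    assert not source.is_module()
    assert source.dataloader() is dataloader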
def test_eval_limit_batches(stage, mode, limit_batches):
    limit_eval_batches = f"limit_{mode}_batches"
    dl_hook = f"{mode}_dataloader"
    model = BoringModel()
    eval_loader = getattr(model, dl_hook)()

    trainer = Trainer(**{limit_eval_batches: limit_batches})
    model.trainer = trainer
    trainer._data_connector.attach_dataloaders(model)

    loader_num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model)
    # A float limit is interpreted as a fraction of the loader length; an int is an absolute cap.
    expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches
    assert loader_num_batches[0] == expected_batches
    assert len(dataloaders[0]) == len(eval_loader)
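
# Hedged illustration of the limit arithmetic asserted above. The helper below is local
# to this sketch, not a Lightning API: a float limit is read as a fraction of the
# dataloader length, an int as an absolute number of batches.
def _expected_eval_batches(limit, num_batches):
    return int(limit * num_batches) if isinstance(limit, float) else limit


def test_expected_eval_batches_sketch():
    assert _expected_eval_batches(0.25, 64) == 16  # 25% of a 64-batch loader
    assert _expected_eval_batches(10, 64) == 10  # absolute cap of 10 batches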
def test_distributed_sampler_with_overfit_batches():
    """Test that overfit_batches disables shuffling in the distributed sampler."""
    model = BoringModel()
    trainer = Trainer(
        overfit_batches=1,
        strategy="ddp_spawn",
        num_processes=2,
    )
    model.trainer = trainer
    trainer.model = model
    trainer._data_connector.attach_dataloaders(model)
    trainer.reset_train_dataloader()

    train_sampler = trainer.train_dataloader.loaders.sampler
    assert isinstance(train_sampler, DistributedSampler)
    assert train_sampler.shuffle is False
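
# Hedged sketch of the sampler property the test relies on: a DistributedSampler built
# with shuffle=False partitions indices deterministically, so every epoch of an overfit
# run sees the same batches. The explicit num_replicas/rank values avoid needing an
# initialized process group and are illustrative only.
def test_distributed_sampler_shuffle_flag_sketch():
    sampler = DistributedSampler(list(range(8)), num_replicas=2, rank=0, shuffle=False)
    assert list(sampler) == [0, 2, 4, 6]  # rank 0 of 2 takes every other index, in order
    assert sampler.shuffle is False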
def test_replication_factor(tmpdir):
    """Ensure that manually passed poptorch Options with a custom replication factor are applied correctly in the
    dataloaders for each stage."""
    plugin = IPUStrategy()
    trainer = Trainer(accelerator="ipu", devices=2, default_root_dir=tmpdir, fast_dev_run=True, strategy=plugin)
    assert isinstance(trainer.accelerator, IPUAccelerator)
    assert trainer.num_devices == 2
    assert trainer.strategy.replication_factor == 2

    model = BoringModel()
    training_opts = poptorch.Options()
    inference_opts = poptorch.Options()
    training_opts.replicationFactor(8)
    inference_opts.replicationFactor(7)
    plugin = IPUStrategy(inference_opts=inference_opts, training_opts=training_opts)
    trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1, strategy=plugin)
    trainer.optimizers = model.configure_optimizers()[0]
    plugin.model = model
    model.trainer = trainer
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.setup(trainer)

    trainer.state.stage = RunningStage.TRAINING
    assert trainer.strategy.replication_factor == 8
    trainer.state.stage = RunningStage.VALIDATING
    assert trainer.strategy.replication_factor == 7

    for fn, stage in (
        (TrainerFn.VALIDATING, RunningStage.VALIDATING),
        (TrainerFn.TESTING, RunningStage.TESTING),
        (TrainerFn.PREDICTING, RunningStage.PREDICTING),
    ):
        trainer.state.fn = fn
        trainer.state.stage = stage
        trainer.strategy.setup(trainer)
        assert trainer.strategy.replication_factor == 7
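
# Note on the assertions above: during FITTING the strategy reports the replication
# factor of whichever stage is active, so TRAINING reads the training options (8) while
# VALIDATING reads the inference options (7); the standalone validate/test/predict entry
# points always compile with the inference options. Illustrative mapping only, restating
# the test's expectations rather than any strategy internals:
EXPECTED_REPLICATION_PER_STAGE = {
    RunningStage.TRAINING: 8,  # from training_opts
    RunningStage.VALIDATING: 7,  # from inference_opts
    RunningStage.TESTING: 7,
    RunningStage.PREDICTING: 7,
}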
def test_default_level_for_hooks_that_support_logging():
    def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs):
        for hook in hooks:
            model._current_fx_name = hook
            model.log(hook, 1)
            result_mock.assert_called_with(
                hook, hook, torch.tensor(1), on_step=on_step, on_epoch=on_epoch, **extra_kwargs
            )

    trainer = Trainer()
    model = BoringModel()
    model.trainer = trainer
    extra_kwargs = {
        k: ANY
        for k in inspect.signature(_ResultCollection.log).parameters
        if k not in ["self", "fx", "name", "value", "on_step", "on_epoch"]
    }
    all_logging_hooks = {k for k in _FxValidator.functions if _FxValidator.functions[k]}

    with mock.patch(
        "pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection.log", return_value=None
    ) as result_mock:
        trainer.state.stage = RunningStage.TRAINING
        hooks = [
            "on_before_backward",
            "backward",
            "on_after_backward",
            "on_before_optimizer_step",
            "optimizer_step",
            "on_before_zero_grad",
            "optimizer_zero_grad",
            "training_step",
            "training_step_end",
            "on_batch_start",
            "on_batch_end",
            "on_train_batch_start",
            "on_train_batch_end",
        ]
        all_logging_hooks = all_logging_hooks - set(hooks)
        _make_assertion(model, hooks, result_mock, on_step=True, on_epoch=False, extra_kwargs=extra_kwargs)

        hooks = [
            "on_train_start",
            "on_train_epoch_start",
            "on_train_epoch_end",
            "on_epoch_start",
            "on_epoch_end",
            "training_epoch_end",
        ]
        all_logging_hooks = all_logging_hooks - set(hooks)
        _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs)

        trainer.state.stage = RunningStage.VALIDATING
        trainer.state.fn = TrainerFn.VALIDATING
        hooks = [
            "on_validation_start",
            "on_validation_epoch_start",
            "on_validation_epoch_end",
            "on_validation_batch_start",
            "on_validation_batch_end",
            "validation_step",
            "validation_step_end",
            "validation_epoch_end",
        ]
        all_logging_hooks = all_logging_hooks - set(hooks)
        _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs)

        trainer.state.stage = RunningStage.TESTING
        trainer.state.fn = TrainerFn.TESTING
        hooks = [
            "on_test_start",
            "on_test_epoch_start",
            "on_test_epoch_end",
            "on_test_batch_start",
            "on_test_batch_end",
            "test_step",
            "test_step_end",
            "test_epoch_end",
        ]
        all_logging_hooks = all_logging_hooks - set(hooks)
        _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs)

    # just to ensure we checked all possible logging hooks here
    assert len(all_logging_hooks) == 0
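
# Hedged usage sketch of the defaults the test pins down: inside a LightningModule,
# self.log() infers on_step/on_epoch from the current hook, so explicit flags are only
# needed to override. The subclass below is illustrative, not part of the suite, and
# assumes BoringModel's training_step/validation_step return {"loss": ...}/{"x": ...}.
class _LoggingDefaultsSketch(BoringModel):
    def training_step(self, batch, batch_idx):
        out = super().training_step(batch, batch_idx)
        self.log("train_loss", out["loss"])  # resolves to on_step=True, on_epoch=False
        return out

    def validation_step(self, batch, batch_idx):
        out = super().validation_step(batch, batch_idx)
        self.log("val_loss", out["x"])  # resolves to on_step=False, on_epoch=True
        return out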