def test_sync_dist(_):
    """Check that a SUM-reducing ``_Sync`` built on ``TPUSpawnStrategy().reduce`` aggregates across processes.

    The positional ``_`` argument is unused (presumably a process index or a patched
    mock supplied by the caller — confirm against the call site).
    """
    sync = _Sync(TPUSpawnStrategy().reduce, should=True, _op=torch.distributed.ReduceOp.SUM)
    value = torch.tensor([1.0])
    # Keep the reduced tensor itself: the previous `(sync(value), )` built a 1-tuple,
    # which has no `.item()` and made the assertion below raise AttributeError.
    value = sync(value)
    # NOTE(review): 8 assumes 8 participating processes each contributing 1.0 — confirm
    # against the spawn configuration that invokes this function.
    assert value.item() == 8
def test_tpu_invalid_raises():
    """A TPU accelerator must be refused with a non-TPU precision plugin or a non-TPU strategy."""

    def _expect_rejection(candidate, pattern):
        # Building the Trainer with `candidate` is expected to fail validation.
        with pytest.raises(ValueError, match=pattern):
            Trainer(strategy=candidate, devices=8)

    # Generic PrecisionPlugin paired with a TPU accelerator.
    spawn_strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin())
    _expect_rejection(spawn_strategy, "TPUAccelerator` can only be used with a `TPUPrecisionPlugin")

    # TPU precision is fine, but DDPStrategy is not an allowed TPU strategy.
    ddp_strategy = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
    _expect_rejection(ddp_strategy, "TPUAccelerator` can only be used with a `SingleTPUStrategy`")
def test_model_tpu_one_core():
    """Tests if device/debug flag is set correctly when training and after teardown for TPUSpawnStrategy."""
    single_core = Trainer(
        accelerator="tpu",
        devices=1,
        fast_dev_run=True,
        strategy=TPUSpawnStrategy(debug=True),
    )

    # The trainer should keep the spawn strategy and report xla device index 1 as its root.
    active_strategy = single_core.strategy
    assert isinstance(active_strategy, TPUSpawnStrategy)
    assert active_strategy.root_device == torch.device("xla", index=1)

    single_core.fit(BoringModelTPU())

    # Debug mode must not leave its environment variable behind once fitting finishes.
    assert "PT_XLA_DEBUG" not in os.environ
def test_tpu_invalid_raises():
    """Trainer must reject a TPU accelerator with a non-TPU precision plugin or with DDPStrategy."""
    # A mocked (non-TPU) precision plugin attached to the TPU spawn strategy.
    mocked_precision = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=Mock())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
        Trainer(strategy=mocked_precision)

    # Correct precision plugin, but a strategy that is not TPU-specific.
    wrong_strategy = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"):
        Trainer(strategy=wrong_strategy)
def test_error_iterable_dataloaders_passed_to_fit(
    _, tmpdir, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders
):
    """Test that the TPUSpawnStrategy identifies dataloaders with iterable datasets and fails early."""
    trainer = Trainer()
    model = BoringModelNoDataloaders()
    model.trainer = trainer

    # Attach every supplied loader to the model through the trainer's data connector.
    supplied_loaders = dict(
        train_dataloaders=train_dataloaders,
        val_dataloaders=val_dataloaders,
        test_dataloaders=test_dataloaders,
        predict_dataloaders=predict_dataloaders,
    )
    trainer._data_connector.attach_dataloaders(model, **supplied_loaders)

    # Connecting the model should surface the unsupported-dataloader error.
    with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
        spawn_strategy = TPUSpawnStrategy(MagicMock())
        spawn_strategy.connect(model)
def test_tpu_invalid_raises_set_precision_with_strategy():
    """Invalid TPU accelerator/precision/strategy combinations passed via a strategy object must raise."""

    def _check_refused(candidate, expected_message):
        # The Trainer constructor is where the combination is validated.
        with pytest.raises(ValueError, match=expected_message):
            Trainer(strategy=candidate)

    tpu_accel = TPUAccelerator()
    _check_refused(
        TPUSpawnStrategy(accelerator=tpu_accel, precision_plugin=object()),
        "`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`",
    )

    # A fresh accelerator instance for the second scenario.
    tpu_accel = TPUAccelerator()
    _check_refused(
        DDPStrategy(accelerator=tpu_accel, precision_plugin=TPUPrecisionPlugin()),
        "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy",
    )
def test_tpu_debug_mode(tmpdir):
    """Test if debug mode works on TPU."""

    class DebugModel(BoringModel):
        def on_train_start(self):
            # Debug mode is expected to export PT_XLA_DEBUG=1 while training runs.
            debug_flag = os.environ.get("PT_XLA_DEBUG")
            assert debug_flag == "1", "PT_XLA_DEBUG was not set in environment variables"

        def teardown(self, stage):
            # ...and to remove it again once the run is torn down.
            assert "PT_XLA_DEBUG" not in os.environ

    tutils.reset_seed()
    options = {
        "default_root_dir": tmpdir,
        "enable_progress_bar": False,
        "max_epochs": 4,
        "tpu_cores": 8,
        "limit_train_batches": 0.4,
        "limit_val_batches": 0.4,
        "strategy": TPUSpawnStrategy(debug=True),
    }

    debug_model = DebugModel()
    tpipes.run_model_test(options, debug_model, on_gpu=False, with_hpc=False)
def test_mp_device_dataloader_attribute(_):
    """The loader returned by ``process_dataloader`` must still expose the original dataset."""
    source_dataset = RandomDataset(32, 64)
    strategy = TPUSpawnStrategy()
    wrapped_loader = strategy.process_dataloader(DataLoader(source_dataset))
    assert wrapped_loader.dataset == source_dataset
def test_strategy_choice_tpu_strategy(tmpdir):
    """A ``TPUSpawnStrategy`` instance passed to the Trainer should be the strategy it ends up using."""
    requested = TPUSpawnStrategy()
    tpu_trainer = Trainer(strategy=requested, accelerator="tpu", devices=8)
    assert isinstance(tpu_trainer.strategy, TPUSpawnStrategy)
def test_error_process_iterable_dataloader(_):
    """Processing a loader without a length on the TPU spawn strategy must raise a misconfiguration error."""
    with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
        spawn_strategy = TPUSpawnStrategy(MagicMock())
        spawn_strategy.process_dataloader(_loader_no_len)
def test_device_type_when_training_plugin_tpu_passed(tmpdir):
    """Passing a TPU spawn strategy with ``accelerator="tpu"`` yields a TPU strategy and TPU accelerator."""
    tpu_trainer = Trainer(strategy=TPUSpawnStrategy(), accelerator="tpu", devices=8)
    chosen_strategy = tpu_trainer.strategy
    assert isinstance(chosen_strategy, TPUSpawnStrategy)
    chosen_accelerator = tpu_trainer.accelerator
    assert isinstance(chosen_accelerator, TPUAccelerator)
def test_device_type_when_training_plugin_tpu_passed(tmpdir):
    """Using ``tpu_cores`` with a TPU spawn strategy should report TPU as the device type and accelerator."""
    tpu_trainer = Trainer(strategy=TPUSpawnStrategy(), tpu_cores=8)

    assert isinstance(tpu_trainer.strategy, TPUSpawnStrategy)
    # Both the enum-based device type and the accelerator object must point at TPU.
    assert tpu_trainer._device_type == _AcceleratorType.TPU
    assert isinstance(tpu_trainer.accelerator, TPUAccelerator)