def test_accelerator_tpu(accelerator, devices):
    """Check that the given accelerator/devices flags resolve to a TPU setup.

    The test requires a TPU to be available. It then builds a ``Trainer`` from
    the given ``accelerator`` and ``devices`` values and checks three things:
    the accelerator is a ``TPUAccelerator``, the strategy is a
    ``TPUSpawnStrategy``, and the trainer resolves to 8 devices.

    NOTE(review): ``accelerator``/``devices`` are presumably supplied by a
    parametrize decorator outside this view; confirm the value combinations.
    """
    # Guard first: the remaining checks are meaningless without a TPU.
    assert TPUAccelerator.is_available()

    tpu_trainer = Trainer(accelerator=accelerator, devices=devices)

    assert isinstance(tpu_trainer.accelerator, TPUAccelerator)
    assert isinstance(tpu_trainer.strategy, TPUSpawnStrategy)
    # The device count is pinned to 8 regardless of the ``devices`` argument.
    # Every parametrized input is expected to resolve to 8 on the test host
    # (assumed 8-core TPU host — confirm).
    assert tpu_trainer.num_devices == 8
def test_accelerator_cpu_with_tpu_cores_flag():
    """Check that an explicit ``accelerator`` choice wins on a TPU-equipped host.

    The test requires a TPU to be available. It then builds two trainers with
    ``devices=8``:

    * ``accelerator="cpu"`` must produce a ``CPUAccelerator``.
    * ``accelerator="tpu"`` must produce a ``TPUAccelerator`` with a
      ``TPUSpawnStrategy``.

    NOTE(review): despite its name, this test passes ``devices=`` and never
    a ``tpu_cores`` flag. The name may be a leftover from an older API;
    confirm before relying on it.
    """
    assert TPUAccelerator.is_available()

    # Explicitly requesting CPU must not be overridden by TPU availability.
    cpu_trainer = Trainer(accelerator="cpu", devices=8)
    assert isinstance(cpu_trainer.accelerator, CPUAccelerator)

    # Explicitly requesting TPU selects the TPU accelerator and spawn strategy.
    tpu_trainer = Trainer(accelerator="tpu", devices=8)
    assert isinstance(tpu_trainer.accelerator, TPUAccelerator)
    assert isinstance(tpu_trainer.strategy, TPUSpawnStrategy)