def test_tpu_invalid_raises():
    """Constructing a ``Trainer`` from an incompatible TPU strategy setup must fail.

    Each case pairs a strategy factory with the ``ValueError`` message fragment
    the ``Trainer`` is expected to raise for that strategy.
    """
    cases = [
        # A non-TPU precision plugin on the TPU spawn strategy.
        (
            lambda: TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin()),
            "TPUAccelerator` can only be used with a `TPUPrecisionPlugin",
        ),
        # A TPU accelerator driven by a strategy other than the TPU ones.
        (
            lambda: DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin()),
            "TPUAccelerator` can only be used with a `SingleTPUStrategy`",
        ),
    ]
    for build_strategy, expected_message in cases:
        # Build the strategy outside the raises-block so only the Trainer call may raise.
        strategy = build_strategy()
        with pytest.raises(ValueError, match=expected_message):
            Trainer(strategy=strategy, devices=8)
# Example #2
# 0
def test_tpu_invalid_raises():
    """``TPUAccelerator.setup`` must reject incompatible plugin combinations.

    Written against the older plugin API, where ``TPUAccelerator`` takes two
    positional plugins (apparently precision plugin first, training-type plugin
    second — inferred from the expected errors below). ``object()`` stands in
    for an arbitrary incompatible plugin and for the argument passed to ``setup``.
    """
    # A precision plugin that is not a TPUPrecisionPlugin must be rejected.
    accelerator = TPUAccelerator(object(), TPUSpawnPlugin())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
        accelerator.setup(object())

    # A training-type plugin other than SingleTPUPlugin / TPUSpawnPlugin must be rejected.
    # The pattern names the full plugin identifier rather than a truncated prefix.
    accelerator = TPUAccelerator(TPUPrecisionPlugin(), object())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin"):
        accelerator.setup(object())
# Example #3
# 0
def test_tpu_invalid_raises():
    """A ``Trainer`` built from a mis-configured TPU strategy must raise ``ValueError``.

    Covers a non-TPU precision plugin (a ``Mock``) on ``TPUSpawnStrategy`` and a
    TPU accelerator inside ``DDPStrategy``.
    """

    def expect_rejection(strategy, message):
        # Only the Trainer construction is inside the raises-block; the strategy
        # argument has already been built by the caller.
        with pytest.raises(ValueError, match=message):
            Trainer(strategy=strategy)

    expect_rejection(
        TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=Mock()),
        "TPUAccelerator` can only be used with a `TPUPrecisionPlugin",
    )
    expect_rejection(
        DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin()),
        "TPUAccelerator` can only be used with a `SingleTPUStrategy`",
    )
 def _map_devices_to_accelerator(self, accelerator: str) -> bool:
     if self.devices is None:
         return False
     if accelerator == _AcceleratorType.TPU and _TPU_AVAILABLE:
         if self.devices == "auto":
             self.devices = TPUAccelerator.auto_device_count()
         self.tpu_cores = device_parser.parse_tpu_cores(self.devices)
         return True
     if accelerator == _AcceleratorType.IPU and _IPU_AVAILABLE:
         if self.devices == "auto":
             self.devices = IPUAccelerator.auto_device_count()
         self.ipus = self.devices
         return True
     if accelerator == _AcceleratorType.GPU and torch.cuda.is_available():
         if self.devices == "auto":
             self.devices = GPUAccelerator.auto_device_count()
         self.gpus = self.devices
         self.parallel_device_ids = device_parser.parse_gpu_ids(
             self.devices)
         return True
     if accelerator == _AcceleratorType.CPU:
         if self.devices == "auto":
             self.devices = CPUAccelerator.auto_device_count()
         if not isinstance(self.devices, int):
             raise MisconfigurationException(
                 "The flag `devices` must be an int with `accelerator='cpu'`,"
                 f" got `devices={self.devices}` instead.")
         self.num_processes = self.devices
         return True
     return False
def test_accelerator_tpu(accelerator, devices):
    """The given ``accelerator``/``devices`` flags must select TPU spawn on 8 devices.

    ``accelerator`` and ``devices`` are presumably supplied by a pytest
    parametrize decorator that is not visible in this excerpt.
    """
    assert TPUAccelerator.is_available()

    trainer = Trainer(accelerator=accelerator, devices=devices)
    # Checked in order: the accelerator first, then the strategy.
    for attribute, expected_type in (("accelerator", TPUAccelerator), ("strategy", TPUSpawnStrategy)):
        assert isinstance(getattr(trainer, attribute), expected_type)
    assert trainer.num_devices == 8
# Example #6
# 0
def test_accelerator_cpu_with_tpu_cores_flag():
    """An explicit ``accelerator`` choice wins even when a TPU is available.

    ``accelerator="cpu"`` must yield a ``CPUAccelerator``. ``accelerator="tpu"``
    must yield a ``TPUAccelerator`` with ``TPUSpawnStrategy``.
    """
    # NOTE(review): the name mentions the `tpu_cores` flag, but the body uses
    # `devices`; confirm which flag this test is meant to cover.
    assert TPUAccelerator.is_available()

    cpu_trainer = Trainer(accelerator="cpu", devices=8)
    assert isinstance(cpu_trainer.accelerator, CPUAccelerator)

    tpu_trainer = Trainer(accelerator="tpu", devices=8)
    assert isinstance(tpu_trainer.accelerator, TPUAccelerator)
    assert isinstance(tpu_trainer.strategy, TPUSpawnStrategy)
# Example #7
# 0
def test_tpu_invalid_raises_set_precision_with_strategy():
    """Incompatible precision/strategy pairings for ``TPUAccelerator`` must be rejected.

    A ``Trainer`` built from either strategy below must raise ``ValueError``
    with the corresponding message.
    """
    precision_error = "`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"
    strategy_error = "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy"

    # A bare object() is not a TPU precision plugin.
    spawn_strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=object())
    with pytest.raises(ValueError, match=precision_error):
        Trainer(strategy=spawn_strategy)

    # DDP is not one of the strategies the TPU accelerator supports.
    ddp_strategy = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
    with pytest.raises(ValueError, match=strategy_error):
        Trainer(strategy=ddp_strategy)