Example #1
# Imports assumed for this snippet; depending on the PyTorch Lightning version,
# ``BoringModel`` may instead live in ``pytorch_lightning.demos.boring_classes``.
from unittest import mock
from unittest.mock import ANY

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import TorchCheckpointIO, XLACheckpointIO
from tests.helpers import BoringModel


def test_trainer_save_checkpoint_storage_options(tmpdir):
    """This test validates that the ``storage_options`` argument is properly passed to ``CheckpointIO``."""
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        enable_checkpointing=False,
    )
    trainer.fit(model)
    instance_path = tmpdir + "/path.ckpt"
    instance_storage_options = "my instance storage options"

    with mock.patch(
            "pytorch_lightning.plugins.io.torch_plugin.TorchCheckpointIO.save_checkpoint"
    ) as io_mock:
        trainer.save_checkpoint(instance_path,
                                storage_options=instance_storage_options)
        io_mock.assert_called_with(ANY,
                                   instance_path,
                                   storage_options=instance_storage_options)
        trainer.save_checkpoint(instance_path)
        io_mock.assert_called_with(ANY, instance_path, storage_options=None)

    with mock.patch(
            "pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint"
    ) as cc_mock:
        trainer.save_checkpoint(instance_path, True)
        cc_mock.assert_called_with(instance_path,
                                   weights_only=True,
                                   storage_options=None)
        trainer.save_checkpoint(instance_path, False, instance_storage_options)
        cc_mock.assert_called_with(instance_path,
                                   weights_only=False,
                                   storage_options=instance_storage_options)

    torch_checkpoint_io = TorchCheckpointIO()
    with pytest.raises(
            TypeError,
            match=
            r"`Trainer.save_checkpoint\(..., storage_options=...\)` with `storage_options` arg"
            f" is not supported for `{torch_checkpoint_io.__class__.__name__}`. Please implement your custom `CheckpointIO`"
            " to define how you'd like to use `storage_options`.",
    ):
        torch_checkpoint_io.save_checkpoint(
            {}, instance_path, storage_options=instance_storage_options)
    xla_checkpoint_io = XLACheckpointIO()
    with pytest.raises(
            TypeError,
            match=
            r"`Trainer.save_checkpoint\(..., storage_options=...\)` with `storage_options` arg"
            f" is not supported for `{xla_checkpoint_io.__class__.__name__}`. Please implement your custom `CheckpointIO`"
            " to define how you'd like to use `storage_options`.",
    ):
        xla_checkpoint_io.save_checkpoint(
            {}, instance_path, storage_options=instance_storage_options)
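The TypeErrors asserted above tell users to bring their own plugin. Below is a minimal sketch (not taken from the source) of such a custom ``CheckpointIO``: it subclasses ``TorchCheckpointIO`` and consumes ``storage_options`` itself before delegating the actual write. The ``upload`` hook on the options object is purely hypothetical and only illustrates one possible use of the extra argument.

from typing import Any, Dict, Optional

from pytorch_lightning.plugins import TorchCheckpointIO


class StorageOptionsCheckpointIO(TorchCheckpointIO):
    """Sketch of a ``CheckpointIO`` that accepts ``storage_options``."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
        # Let the stock torch-based plugin do the actual write; it raises on a
        # non-None ``storage_options``, so the extra argument is handled here.
        super().save_checkpoint(checkpoint, path, storage_options=None)
        if storage_options is not None and hasattr(storage_options, "upload"):
            # Hypothetical hook: hand the written checkpoint file to a user-supplied uploader.
            storage_options.upload(path)


# Usage sketch: register the plugin so the Trainer forwards storage_options to it, e.g.
# trainer = Trainer(plugins=[StorageOptionsCheckpointIO()])
# trainer.save_checkpoint("model.ckpt", storage_options=my_uploader)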
Example #2
    def __init__(self,
                 parallel_devices: Optional[List[int]] = None,
                 checkpoint_io: Optional[CheckpointIO] = None,
                 debug: bool = False,
                 **_: Any) -> None:
        checkpoint_io = checkpoint_io or XLACheckpointIO()
        super().__init__(parallel_devices=parallel_devices,
                         checkpoint_io=checkpoint_io)
        self.debug = debug
        self.tpu_local_core_rank = 0
        self.tpu_global_core_rank = 0
        self.start_method = None
Example #3
    def __init__(
        self,
        device: int,
        checkpoint_io: Optional[CheckpointIO] = None,
        debug: bool = False,
    ):

        device = xm.xla_device(device)
        checkpoint_io = checkpoint_io or XLACheckpointIO()
        super().__init__(device=device, checkpoint_io=checkpoint_io)

        self.debug = debug
        self.tpu_local_core_rank = 0
        self.tpu_global_core_rank = 0
Example #4
    def __init__(
        self,
        device: int,
        accelerator: Optional[
            "pl.accelerators.accelerator.Accelerator"] = None,
        checkpoint_io: Optional[XLACheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
    ):
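        # ``xm`` here is ``torch_xla.core.xla_model`` (imported in the source module as
        # ``import torch_xla.core.xla_model as xm`` when TPUs are available).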
        device = xm.xla_device(device)
        checkpoint_io = checkpoint_io or XLACheckpointIO()
        super().__init__(accelerator=accelerator,
                         device=device,
                         checkpoint_io=checkpoint_io,
                         precision_plugin=precision_plugin)

        self.debug = debug
        self.tpu_local_core_rank = 0
        self.tpu_global_core_rank = 0
Example #5
    def __init__(
        self,
        accelerator: Optional[
            "pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[int]] = None,
        checkpoint_io: Optional[XLACheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
        **_: Any,
    ) -> None:
        checkpoint_io = checkpoint_io or XLACheckpointIO()
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self.debug = debug
        self.tpu_local_core_rank = 0
        self.tpu_global_core_rank = 0
        self.start_method = "fork"

    @property
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = XLACheckpointIO()
        return self._checkpoint_io
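For context, a minimal usage sketch (assuming a TPU environment with ``torch_xla`` installed and a recent Trainer argument style): the XLA-aware plugin can be supplied to the Trainer explicitly via ``plugins``; otherwise, as the property above shows, the TPU strategy creates an ``XLACheckpointIO`` lazily on first access.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import XLACheckpointIO

# The accelerator/devices values are illustrative and depend on the available TPU slice.
trainer = Trainer(accelerator="tpu", devices=8, plugins=[XLACheckpointIO()])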