コード例 #1
0
ファイル: imports.py プロジェクト: lujun59/nlp_learning
# Feature/availability flags, computed by the module-level statements below.
# NOTE(review): `_compare_version` and `_module_available` are defined outside
# this view; the comments below assume they compare an installed package
# version with the given operator and probe whether a module path can be
# found, respectively -- confirm against their definitions.

_IS_WINDOWS = platform.system() == "Windows"

# PyTorch version gates.
# Strictly below 1.5.0: the previous `operator.le` also matched torch 1.5.0
# itself, which contradicts the "lower or equal 1.4" name.
_TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.lt, "1.5.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")

# True when at least one supported quantization engine other than 'none'
# is listed by torch.
_TORCH_QUANTIZE_AVAILABLE = any(
    eg != 'none' for eg in torch.backends.quantized.supported_engines)

# Optional dependencies. Flags prefixed with `not _IS_WINDOWS and` are always
# False on Windows, and the module probe is skipped there (short-circuit).
_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available(
    'fairscale.nn.data_parallel')
# Requires torch >= 1.6 and the fairscale version check against 0.1.3.
_FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version(
    "fairscale", operator.le, "0.1.3")
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available(
    'torch.distributed.group')
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
_HYDRA_AVAILABLE = _module_available("hydra")
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
# `hasattr` is only evaluated when the module probe succeeded, so
# `torch.cuda.amp` is not touched on builds that lack it.
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(
    torch.cuda.amp, "autocast")
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
_RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
_XLA_AVAILABLE = _module_available("torch_xla")

# NOTE(review): this import is deliberately placed after the flag definitions
# (hence the E402 suppression) -- presumably because xla_device depends on
# names from this module and an earlier import would be circular; confirm
# before moving it.
from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: E402

# True when XLADeviceUtils reports that a TPU device exists.
_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
コード例 #2
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.xla_device import XLADeviceUtils
from tests.helpers.boring_model import BoringModel
from tests.helpers.utils import pl_multi_process_test


@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(),
                    reason="test requires TPU machine")
@pl_multi_process_test
def test_resume_training_on_cpu(tmpdir):
    """ Checks if training can be resumed from a saved checkpoint on CPU"""
    # Train a model on TPU
    model = BoringModel()
    trainer = Trainer(
        checkpoint_callback=True,
        max_epochs=1,
        tpu_cores=8,
    )
    trainer.fit(model)

    model_path = trainer.checkpoint_callback.best_model_path