Example #1
import importlib.util
import platform
from distutils.version import LooseVersion
from enum import Enum

import numpy
import torch

from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils


def _module_available(module_path: str) -> bool:
    """Check if a dotted module path is importable in the current environment."""
    try:
        # Resolve the dotted path one package level at a time; any missing
        # parent package means the full path is unavailable.
        mods = module_path.split('.')
        for i in range(len(mods)):
            module_path = '.'.join(mods[:i + 1])
            if importlib.util.find_spec(module_path) is None:
                return False
        return True
    except AttributeError:
        return False
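
# Illustrative usage (a sketch, not part of the original module): every dotted
# prefix must resolve for the check to succeed.
#
#     _module_available('os')           # True: 'os' is importable
#     _module_available('os.path')      # True: both 'os' and 'os.path' resolve
#     _module_available('no.such.pkg')  # False: top-level 'no' is missing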


APEX_AVAILABLE = _module_available("apex.amp")
NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
OMEGACONF_AVAILABLE = _module_available("omegaconf")
HYDRA_AVAILABLE = _module_available("hydra")
HOROVOD_AVAILABLE = _module_available("horovod.torch")

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc')
GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
FAIRSCALE_PIPE_AVAILABLE = FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) == LooseVersion("1.6.0")
BOLTS_AVAILABLE = _module_available('pl_bolts')
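
# Sketch of typical downstream use of these flags (an assumption, not code from
# this file): guard optional imports so missing extras do not fail at import time.
#
#     if APEX_AVAILABLE:
#         from apex import amp
#     if OMEGACONF_AVAILABLE:
#         from omegaconf import OmegaConf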

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps


class LightningEnum(str, Enum):
    """ Type of any enumerator with allowed comparison to string invariant to cases. """

    @classmethod
    def from_str(cls, value: str) -> 'LightningEnum':
        # Case-insensitive lookup of an enum member by its name.
        statuses = [status for status in dir(cls) if not status.startswith('_')]
        for st in statuses:
            if st.lower() == value.lower():
                return getattr(cls, st)
        return None
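

# Minimal usage sketch (hypothetical subclass, not from the original source):
#
#     class AcceleratorKind(LightningEnum):
#         CPU = 'cpu'
#         TPU = 'tpu'
#
#     AcceleratorKind.from_str('TPU') is AcceleratorKind.TPU  # lookup ignores case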


# Tests for XLADeviceUtils.tpu_device_exists (bound here as ``xdu``).
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu


def test_tpu_device_absence():
    """Check tpu_device_exists returns None when torch_xla is not available"""
    assert xdu.tpu_device_exists() is None


def test_tpu_device_presence():
    """Check tpu_device_exists returns True when TPU is available"""
    assert xdu.tpu_device_exists() is True
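
# In practice such checks are usually gated with pytest markers depending on
# torch_xla availability; a sketch (decorator details are an assumption):
#
#     XLA_AVAILABLE = _module_available("torch_xla")
#
#     @pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla")
#     def test_tpu_device_presence():
#         ...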
Example #4
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from tests.base.boring_model import BoringModel
from tests.base.develop_utils import pl_multi_process_test


@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(),
                    reason="test requires TPU machine")
@pl_multi_process_test
def test_resume_training_on_cpu(tmpdir):
    """ Checks if training can be resumed from a saved checkpoint on CPU"""

    # Train a model on TPU
    model = BoringModel()
    trainer = Trainer(
        checkpoint_callback=True,
        max_epochs=1,
        tpu_cores=8,
    )
    trainer.fit(model)

    model_path = trainer.checkpoint_callback.best_model_path
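
    # Sketch of how the test likely continues (the listing is truncated here;
    # this continuation is an assumption, not the original code):
    #
    #     model = BoringModel()
    #     trainer = Trainer(
    #         resume_from_checkpoint=model_path,
    #         checkpoint_callback=True,
    #         max_epochs=1,
    #     )
    #     trainer.fit(model)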