def requires_nccl_version(version, msg):
    """Return a test-skip helper gated on the installed NCCL version.

    When c10d was built without the NCCL backend, the result of
    ``sandcastle_skip`` is returned unconditionally. Otherwise the result of
    ``sandcastle_skip_if`` is returned, triggered when
    ``torch.cuda.nccl.version() < version``. Both are presumably
    skip decorators (their definitions are not visible here).

    Args:
        version: Minimum acceptable NCCL version. It is compared with ``<``
            against ``torch.cuda.nccl.version()``, so it must be of a
            comparable type. NOTE(review): whether that is an int or a tuple
            depends on the torch release — confirm against callers.
        msg: Free-form reason appended to the skip message.
    """
    if c10d.is_nccl_available():
        too_old = torch.cuda.nccl.version() < version
        reason = (
            f"Requires NCCL version greater than or equal to: {version}, "
            f"found: {torch.cuda.nccl.version()}, reason: {msg}"
        )
        return sandcastle_skip_if(too_old, reason)
    return sandcastle_skip("c10d was not compiled with the NCCL backend")
def skip_if_win32():
    """Return a test-skip helper that fires when running on Windows.

    The condition is ``sys.platform == 'win32'``; the returned object is
    whatever ``sandcastle_skip_if`` produces (presumably a skip decorator).
    """
    return sandcastle_skip_if(
        sys.platform == 'win32',
        # Fixed typo ("supportted" -> "supported") in the user-facing reason.
        "This unit test case is not supported on Windows platform",
    )
def requires_mpi():
    """Return a test-skip helper that fires when c10d lacks the MPI backend.

    Availability is queried via ``c10d.is_mpi_available()``; the returned
    object comes from ``sandcastle_skip_if`` (presumably a skip decorator).
    """
    mpi_missing = not c10d.is_mpi_available()
    reason = "c10d was not compiled with the MPI backend"
    return sandcastle_skip_if(mpi_missing, reason)
def requires_nccl():
    """Return a test-skip helper that fires when c10d lacks the NCCL backend.

    Availability is queried via ``c10d.is_nccl_available()``; the returned
    object comes from ``sandcastle_skip_if`` (presumably a skip decorator).
    """
    nccl_missing = not c10d.is_nccl_available()
    reason = "c10d was not compiled with the NCCL backend"
    return sandcastle_skip_if(nccl_missing, reason)
def requires_gloo():
    """Return a test-skip helper that fires when c10d lacks the Gloo backend.

    Availability is queried via ``c10d.is_gloo_available()``; the returned
    object comes from ``sandcastle_skip_if`` (presumably a skip decorator).
    """
    gloo_missing = not c10d.is_gloo_available()
    reason = "c10d was not compiled with the Gloo backend"
    return sandcastle_skip_if(gloo_missing, reason)
from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, TEST_WITH_DEV_DBG_ASAN, sandcastle_skip_if, ) from torch.testing._internal.common_cuda import CUDA11OrLater try: import torchvision HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False skipIfNoTorchVision = sandcastle_skip_if(not HAS_TORCHVISION, "no torchvision") if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) if TEST_WITH_DEV_DBG_ASAN: print( "Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr, ) sys.exit(0) # Various mixed precision configs to test under. default_mp = MixedPrecision( param_dtype=torch.float16,