def require_apex(test_case):
    """
    Decorator marking a test that requires apex.

    When ``is_apex_available()`` is falsy, ``test_case`` is wrapped with
    ``unittest.skip("test requires apex")``. Otherwise ``test_case`` is
    returned unchanged.
    """
    # skipUnless applies unittest.skip(reason) when the condition is falsy
    # and returns the object untouched otherwise.
    skip_without_apex = unittest.skipUnless(is_apex_available(), "test requires apex")
    return skip_without_apex(test_case)
def is_cuda_and_apex_available():
    """
    Report whether tests run on CUDA and apex is available.

    Returns ``False`` when ``torch.cuda.is_available()`` is false or the
    module-level ``torch_device`` is not exactly ``"cuda"``. Otherwise it
    returns the result of ``is_apex_available()``.
    """
    # Check CUDA first, so is_apex_available() is only consulted on a CUDA device.
    if not (torch.cuda.is_available() and torch_device == "cuda"):
        return False
    return is_apex_available()