def instantiate_test(cls, name, test, *, generic_cls=None):
    """Registers device/dtype/op-specialized variants of ``test`` on ``cls``.

    For tests carrying an ``op_list`` (the ops decorator), one method is
    registered per (op, dtype) pair, with dtypes taken from the dtypes
    decorator or derived from the op according to ``test.opinfo_dtypes``.
    Other tests get one method per dtype from the dtypes decorator (or a
    single dtype-less method).

    Args:
        cls: the device-specific test class the methods are attached to.
        name: base name of the generic test.
        test: the generic test function.
        generic_cls: class whose ``__name__`` is passed to ``op.should_skip``;
            it is dereferenced whenever an op is present, so ops-decorated
            tests require it.

    Raises:
        RuntimeError: if ``test.opinfo_dtypes`` is not a recognized OpDTypes.
        AssertionError: on a test-name collision or an incompatible
            combination of the ops and dtypes decorators.
    """
    def instantiate_test_helper(cls, name, *, test, dtype, op):
        # Constructs the test's name
        test_name = _construct_test_name(name, op, cls.device_type, dtype)

        # Wraps instantiated test with op decorators
        # NOTE: test_wrapper exists because we don't want to apply
        #   op-specific decorators to the original test.
        #   Test-specific decorators are applied to the original test,
        #   however.
        if op is not None and op.decorators is not None:
            @wraps(test)
            def test_wrapper(*args, **kwargs):
                return test(*args, **kwargs)

            for decorator in op.decorators:
                test_wrapper = decorator(test_wrapper)

            test_fn = test_wrapper
        else:
            test_fn = test

        # Constructs the test
        # NOTE: wraps test_fn (not test) so attributes set by op decorators --
        #   e.g. unittest.expectedFailure's __unittest_expecting_failure__ --
        #   are copied onto the method unittest actually inspects. Wrapping
        #   the undecorated test silently dropped them.
        @wraps(test_fn)
        def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op):
            if op is not None and op.should_skip(
                    generic_cls.__name__, name, self.device_type, dtype):
                self.skipTest("Skipped!")

            device_arg: str = cls.get_primary_device()
            if hasattr(test_fn, 'num_required_devices'):
                device_arg = cls.get_all_devices()

            # Sets precision and runs test
            # Note: precision is reset after the test is run
            guard_precision = self.precision
            try:
                self.precision = self._get_precision_override(test_fn, dtype)
                # Only non-None values are forwarded positionally
                args = (arg for arg in (device_arg, dtype, op) if arg is not None)
                result = test_fn(self, *args)
            except RuntimeError as rte:
                # check if rte should stop entire test suite.
                self._stop_test_suite = self._should_stop_test_suite(rte)
                # re-raise the runtime error as is for the test suite to record.
                raise
            finally:
                self.precision = guard_precision

            return result

        assert not hasattr(
            cls, test_name), "Redefinition of test {0}".format(test_name)
        setattr(cls, test_name, instantiated_test)

    # Handles tests using the ops decorator
    if hasattr(test, "op_list"):
        for op in test.op_list:
            # Acquires dtypes, using the op data if unspecified
            dtypes = cls._get_dtypes(test)
            if dtypes is None:
                if test.opinfo_dtypes == OpDTypes.unsupported:
                    dtypes = set(get_all_dtypes()).difference(
                        op.supported_dtypes(cls.device_type))
                elif test.opinfo_dtypes == OpDTypes.supported:
                    dtypes = op.supported_dtypes(cls.device_type)
                elif test.opinfo_dtypes == OpDTypes.basic:
                    dtypes = op.default_test_dtypes(cls.device_type)
                else:
                    raise RuntimeError(
                        f"Unknown OpDType: {test.opinfo_dtypes}")

                if test.allowed_dtypes is not None:
                    dtypes = dtypes.intersection(test.allowed_dtypes)
            else:
                assert test.allowed_dtypes is None, "ops(allowed_dtypes=[...]) and the dtypes decorator are incompatible"
                assert test.opinfo_dtypes == OpDTypes.basic, "ops(dtypes=...) and the dtypes decorator are incompatible"

            for dtype in dtypes:
                instantiate_test_helper(cls, name, test=test, dtype=dtype, op=op)
    else:
        # Handles tests that don't use the ops decorator
        dtypes = cls._get_dtypes(test)
        dtypes = tuple(dtypes) if dtypes is not None else (None, )
        for dtype in dtypes:
            instantiate_test_helper(cls, name, test=test, dtype=dtype, op=None)
def instantiate_test(cls, name, test, *, generic_cls=None):
    """Registers device/dtype/op-specialized variants of ``test`` on ``cls``.

    For tests carrying an ``op_list`` (the ops decorator), one method is
    registered per (op, dtype) pair. Dtypes come from the dtypes decorator
    or are derived from the op according to ``test.opinfo_dtypes``
    (forward/backward supported or unsupported, basic, or none). Other
    tests get one method per dtype from the dtypes decorator, or a single
    dtype-less method.

    Op-level skips (``op.should_skip``) and op decorators, including
    ``DecorateInfo`` entries active for this class/test/device/dtype, are
    resolved once at instantiation time and applied to a wrapper around
    ``test``, never to ``test`` itself.

    Args:
        cls: the device-specific test class the methods are attached to.
        name: base name of the generic test.
        test: the generic test function.
        generic_cls: class whose ``__name__`` is passed to ``should_skip`` /
            ``is_active``; it is dereferenced whenever an op is present.

    Raises:
        RuntimeError: if ``test.opinfo_dtypes`` is not a recognized OpDTypes.
        AssertionError: on a test-name collision or an incompatible
            combination of the ops and dtypes decorators.
    """
    def instantiate_test_helper(cls, name, *, test, dtype, op):
        # Constructs the test's name
        test_name = _construct_test_name(name, op, cls.device_type, dtype)

        # Wraps instantiated test with op decorators
        # NOTE: test_wrapper exists because we don't want to apply
        #   op-specific decorators to the original test.
        #   Test-specific decorators are applied to the original test,
        #   however.
        if op is not None:
            try:
                active_decorators = []
                # An op-level skip becomes an ordinary skipIf decorator
                if op.should_skip(generic_cls.__name__, name, cls.device_type, dtype):
                    active_decorators.append(skipIf(True, "Skipped!"))

                if op.decorators is not None:
                    for decorator in op.decorators:
                        # Can't use isinstance as it would cause a circular import
                        if decorator.__class__.__name__ == 'DecorateInfo':
                            # DecorateInfo contributes its decorators only when
                            # active for this class/test/device/dtype
                            if decorator.is_active(generic_cls.__name__, name, cls.device_type, dtype):
                                active_decorators += decorator.decorators
                        else:
                            active_decorators.append(decorator)

                @wraps(test)
                def test_wrapper(*args, **kwargs):
                    return test(*args, **kwargs)

                for decorator in active_decorators:
                    test_wrapper = decorator(test_wrapper)

                test_fn = test_wrapper
            except Exception as ex:
                # Provides an error message for debugging before rethrowing the exception
                print("Failed to instantiate {0} for op {1}!".format(
                    test_name, op.name))
                raise ex
        else:
            test_fn = test

        # Constructs the test
        # Wrapping test_fn (the decorated wrapper) copies attributes the op
        # decorators set onto the registered method.
        # NOTE: the `name` default is bound but not read in this body.
        @wraps(test_fn)
        def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op):
            device_arg: str = cls.get_primary_device()
            if hasattr(test_fn, 'num_required_devices'):
                device_arg = cls.get_all_devices()

            # Sets precision and runs test
            # Note: precision is reset after the test is run
            guard_precision = self.precision
            try:
                self.precision = self._get_precision_override(
                    test_fn, dtype)
                # Only non-None values are forwarded positionally
                args = (arg for arg in (device_arg, dtype, op) if arg is not None)
                result = test_fn(self, *args)
            except RuntimeError as rte:
                # check if rte should stop entire test suite.
                self._stop_test_suite = self._should_stop_test_suite()
                # raise the runtime error as is for the test suite to record.
                raise rte
            finally:
                self.precision = guard_precision

            return result

        assert not hasattr(
            cls, test_name), "Redefinition of test {0}".format(test_name)
        setattr(cls, test_name, instantiated_test)

    # Handles tests using the ops decorator
    if hasattr(test, "op_list"):
        for op in test.op_list:
            # Acquires dtypes, using the op data if unspecified
            dtypes = cls._get_dtypes(test)
            if dtypes is None:
                if test.opinfo_dtypes == OpDTypes.unsupported_backward:
                    dtypes = set(get_all_dtypes()).difference(
                        op.supported_backward_dtypes(cls.device_type))
                elif test.opinfo_dtypes == OpDTypes.supported_backward:
                    dtypes = op.supported_backward_dtypes(cls.device_type)
                elif test.opinfo_dtypes == OpDTypes.unsupported:
                    dtypes = set(get_all_dtypes()).difference(
                        op.supported_dtypes(cls.device_type))
                elif test.opinfo_dtypes == OpDTypes.supported:
                    dtypes = op.supported_dtypes(cls.device_type)
                elif test.opinfo_dtypes == OpDTypes.basic:
                    dtypes = op.default_test_dtypes(cls.device_type)
                elif test.opinfo_dtypes == OpDTypes.none:
                    # _NO_DTYPES is compared by identity below
                    dtypes = _NO_DTYPES
                else:
                    raise RuntimeError(
                        f"Unknown OpDType: {test.opinfo_dtypes}")

                # NOTE(review): when opinfo_dtypes is OpDTypes.none, this calls
                #   .intersection on _NO_DTYPES; confirm that object supports it
                #   or that allowed_dtypes is never combined with OpDTypes.none.
                if test.allowed_dtypes is not None:
                    dtypes = dtypes.intersection(test.allowed_dtypes)
            else:
                assert test.allowed_dtypes is None, "ops(allowed_dtypes=[...]) and the dtypes decorator are incompatible"
                assert test.opinfo_dtypes == OpDTypes.basic, "ops(dtypes=...) and the dtypes decorator are incompatible"

            if dtypes is _NO_DTYPES:
                # A single variant is instantiated with _NO_DTYPES as its dtype.
                # NOTE(review): it is dropped from the test's positional args
                #   only if it is None; confirm its definition.
                instantiate_test_helper(cls, name, test=test, dtype=_NO_DTYPES, op=op)
            else:
                for dtype in dtypes:
                    instantiate_test_helper(cls, name, test=test, dtype=dtype, op=op)
    else:
        # Handles tests that don't use the ops decorator
        dtypes = cls._get_dtypes(test)
        dtypes = tuple(dtypes) if dtypes is not None else (None, )
        for dtype in dtypes:
            instantiate_test_helper(cls, name, test=test, dtype=dtype, op=None)
import torch from torch.testing import get_all_dtypes dtypes = get_all_dtypes() dtypes.remove(torch.half) grad_dtypes = [torch.float, torch.double] devices = [torch.device('cpu')] if torch.cuda.is_available(): devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))] def tensor(x, dtype, device): return None if x is None else torch.tensor(x, dtype=dtype, device=device)
def instantiate_test(cls, name, test, *, generic_cls=None):
    """Registers device/dtype/op-specialized variants of ``test`` on ``cls``.

    For tests carrying an ``op_list`` (the ops decorator), one method is
    registered per (op, dtype) pair. Dtypes come from the dtypes decorator,
    else from the op's device-specific list (``dtypesIfCPU`` /
    ``dtypesIfCUDA`` / ``dtypesIfROCM``), falling back to ``op.dtypes``.
    They are inverted against ``get_all_dtypes()`` when the test wants only
    unsupported dtypes. Other tests get one method per dtype from the
    dtypes decorator, or a single dtype-less method.

    Args:
        cls: the device-specific test class the methods are attached to.
        name: base name of the generic test.
        test: the generic test function.
        generic_cls: class whose ``__name__`` is passed to ``op.should_skip``;
            it is dereferenced whenever an op is present.

    Raises:
        AssertionError: if a generated test name already exists on ``cls``.
    """
    def instantiate_test_helper(cls, name, *, test, dtype, op):
        # Constructs the test's name
        test_name = _construct_test_name(name, op, cls.device_type, dtype)

        # Wraps instantiated test with op decorators
        # NOTE: test_wrapper exists because we don't want to apply
        #   op-specific decorators to the original test.
        #   Test-specific decorators are applied to the original test,
        #   however.
        if op is not None and op.decorators is not None:
            @wraps(test)
            def test_wrapper(*args, **kwargs):
                return test(*args, **kwargs)

            for decorator in op.decorators:
                test_wrapper = decorator(test_wrapper)

            test_fn = test_wrapper
        else:
            test_fn = test

        # Constructs the test
        # NOTE: wraps test_fn (not test) so attributes set by op decorators --
        #   e.g. unittest.expectedFailure's __unittest_expecting_failure__ --
        #   are copied onto the method unittest actually inspects. Wrapping
        #   the undecorated test silently dropped them.
        @wraps(test_fn)
        def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op):
            if op is not None and op.should_skip(
                    generic_cls.__name__, name, self.device_type, dtype):
                self.skipTest("Skipped!")

            device_arg: str = cls.get_primary_device()
            if hasattr(test_fn, 'num_required_devices'):
                device_arg = cls.get_all_devices()

            # Sets precision and runs test
            # Note: precision is reset after the test is run
            guard_precision = self.precision
            try:
                self.precision = self._get_precision_override(test_fn, dtype)
                # Only non-None values are forwarded positionally
                args = (arg for arg in (device_arg, dtype, op) if arg is not None)
                result = test_fn(self, *args)
            finally:
                self.precision = guard_precision

            return result

        assert not hasattr(
            cls, test_name), "Redefinition of test {0}".format(test_name)
        setattr(cls, test_name, instantiated_test)

    # Handles tests using the ops decorator
    if hasattr(test, "op_list"):
        for op in test.op_list:
            # Acquires dtypes, using the op data if unspecified
            dtypes = cls._get_dtypes(test)
            if dtypes is None:
                if cls.device_type == 'cpu' and op.dtypesIfCPU is not None:
                    dtypes = op.dtypesIfCPU
                elif (cls.device_type == 'cuda' and not TEST_WITH_ROCM
                      and op.dtypesIfCUDA is not None):
                    dtypes = op.dtypesIfCUDA
                elif (cls.device_type == 'cuda' and TEST_WITH_ROCM
                      and op.dtypesIfROCM is not None):
                    dtypes = op.dtypesIfROCM
                else:
                    dtypes = op.dtypes

            # Inverts dtypes if the function wants unsupported dtypes
            if test.unsupported_dtypes_only is True:
                dtypes = [d for d in get_all_dtypes() if d not in dtypes]
            dtypes = dtypes if dtypes is not None else (None, )
            for dtype in dtypes:
                instantiate_test_helper(cls, name, test=test, dtype=dtype, op=op)
    else:
        # Handles tests that don't use the ops decorator
        dtypes = cls._get_dtypes(test)
        dtypes = tuple(dtypes) if dtypes is not None else (None, )
        for dtype in dtypes:
            instantiate_test_helper(cls, name, test=test, dtype=dtype, op=None)
def test_dtypes(self, device, op):
    """Checks that an OpInfo's claimed dtypes match what the op actually supports.

    Every dtype from ``get_all_dtypes()`` is probed on ``device``:

    * a dtype is forward-unsupported if ``op.sample_inputs`` or any call of
      the op on a sample raises;
    * a dtype in the allowed backward set is backward-supported if at least
      one sample's (processed) tensor output can be backpropagated through,
      and no attempt raised.

    The detected sets are then compared with ``op.supported_dtypes`` and
    ``op.supported_backward_dtypes`` for the device type, and any mismatch
    fails with a message listing what to add to or remove from the OpInfo.
    """
    # dtypes to try to backward in
    allowed_backward_dtypes = floating_and_complex_types_and(
        torch.bfloat16, torch.float16)

    # lists for (un)supported dtypes
    supported_dtypes = []
    unsupported_dtypes = []
    supported_backward_dtypes = []
    unsupported_backward_dtypes = []

    def unsupported(dtype):
        # A forward failure also rules out backward support for that dtype
        unsupported_dtypes.append(dtype)
        if dtype in allowed_backward_dtypes:
            unsupported_backward_dtypes.append(dtype)

    for dtype in get_all_dtypes():
        # tries to acquire samples - failure indicates lack of support
        requires_grad = (dtype in allowed_backward_dtypes and op.supports_autograd)
        try:
            samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
        except Exception:
            unsupported(dtype)
            continue

        # Counts number of successful backward attempts
        # NOTE: This exists as a kludge because this only understands how to
        #   request a gradient if the output is a tensor or a sequence with
        #   a tensor as its first element.
        num_backward_successes = 0
        for sample in samples:
            # tries to call operator with the sample - failure indicates
            #   lack of support
            try:
                result = op(sample.input, *sample.args, **sample.kwargs)
            except Exception:
                # NOTE: some ops will fail in forward if their inputs
                #   require grad but they don't support computing the gradient
                #   in that type! This is a bug in the op!
                unsupported(dtype)

            # Short-circuits testing this dtype -- it doesn't work
            if dtype in unsupported_dtypes:
                break

            # Short-circuits if the dtype isn't a backward dtype or
            #   it's already identified as not supported
            if dtype not in allowed_backward_dtypes or dtype in unsupported_backward_dtypes:
                continue

            # Checks for backward support in the same dtype
            try:
                result = sample.output_process_fn_grad(result)
                if isinstance(result, torch.Tensor):
                    backward_tensor = result
                elif isinstance(result, Sequence) and isinstance(
                        result[0], torch.Tensor):
                    backward_tensor = result[0]
                else:
                    continue

                # Note: this grad may not have the same dtype as dtype
                #   For functions like complex (float -> complex) or abs
                #   (complex -> float) the grad tensor will have a
                #   different dtype than the input.
                #   For simplicity, this is still modeled as these ops
                #   supporting grad in the input dtype.
                grad = torch.randn_like(backward_tensor)
                backward_tensor.backward(grad)
                num_backward_successes += 1
            except Exception:
                unsupported_backward_dtypes.append(dtype)

        if dtype not in unsupported_dtypes:
            supported_dtypes.append(dtype)
        if num_backward_successes > 0 and dtype not in unsupported_backward_dtypes:
            supported_backward_dtypes.append(dtype)

    device_type = torch.device(device).type

    def check_claims(kind, detected, claimed):
        # Compares detected vs. claimed dtype sets and fails with an
        #   informative message; `kind` is "dtypes" or "backward dtypes".
        detected = set(detected)
        claimed = set(claimed)
        supported_but_unclaimed = detected - claimed
        claimed_but_unsupported = claimed - detected
        msg = """The supported {0} for {1} on {2} according to its OpInfo are {3}, but the detected supported {0} are {4}. """.format(
            kind, op.name, device_type, claimed, detected)
        if supported_but_unclaimed:
            msg += "The following {0} should be added to the OpInfo: {1}. ".format(
                kind, supported_but_unclaimed)
        if claimed_but_unsupported:
            msg += "The following {0} should be removed from the OpInfo: {1}.".format(
                kind, claimed_but_unsupported)
        self.assertEqual(detected, claimed, msg=msg)

    # Checks that forward dtypes, then backward dtypes, are listed correctly
    check_claims("dtypes", supported_dtypes, op.supported_dtypes(device_type))
    check_claims("backward dtypes", supported_backward_dtypes,
                 op.supported_backward_dtypes(device_type))
def instantiate_test(cls, name, test):
    """Attaches device/dtype/op-specialized copies of ``test`` to ``cls``.

    Method names are ``<name>[_<op.name>]_<device_type>[_<dtype>...]``, where
    each dtype contributes the part of ``str(dtype)`` after its first dot.
    For an ops-decorated test (one with ``op_list``), the dtypes for each op
    come from the dtypes decorator, else from the op's device-specific list,
    falling back to ``op.dtypes``. They are inverted against
    ``get_all_dtypes()`` when ``test.unsupported_dtypes_only`` is True. The
    op's decorators wrap each registered method.

    Raises:
        AssertionError: if a generated method name already exists on ``cls``.
    """

    def _variant_name(dtype, op):
        # Joins the name pieces with underscores; a list/tuple dtype
        #   contributes one piece per element.
        pieces = [name]
        if op is not None:
            pieces.append(op.name)
        pieces.append(cls.device_type)
        if dtype is not None:
            dtype_seq = dtype if isinstance(dtype, (list, tuple)) else (dtype,)
            pieces.extend(str(d).split('.')[1] for d in dtype_seq)
        return "_".join(pieces)

    def _register(dtype, op):
        # Builds one concrete test method and attaches it to cls.
        method_name = _variant_name(dtype, op)

        @wraps(test)
        def instantiated_test(self, test=test, dtype=dtype, op=op):
            if hasattr(test, 'num_required_devices'):
                device_arg = cls.get_all_devices()
            else:
                device_arg = cls.get_primary_device()

            # The precision override only lasts for this call.
            saved_precision = self.precision
            try:
                self.precision = self._get_precision_override(test, dtype)
                call_args = [a for a in (device_arg, dtype, op) if a is not None]
                outcome = test(self, *call_args)
            finally:
                self.precision = saved_precision
            return outcome

        decorated = instantiated_test
        if op is not None and op.decorators is not None:
            for deco in op.decorators:
                decorated = deco(decorated)

        assert not hasattr(cls, method_name), "Redefinition of test {0}".format(method_name)
        setattr(cls, method_name, decorated)

    def _dtypes_for(op):
        # Resolves the dtypes to instantiate for one op.
        chosen = cls._get_dtypes(test)
        if chosen is None:
            on_cuda = cls.device_type == 'cuda'
            if cls.device_type == 'cpu' and op.dtypesIfCPU is not None:
                chosen = op.dtypesIfCPU
            elif on_cuda and not TEST_WITH_ROCM and op.dtypesIfCUDA is not None:
                chosen = op.dtypesIfCUDA
            elif on_cuda and TEST_WITH_ROCM and op.dtypesIfROCM is not None:
                chosen = op.dtypesIfROCM
            else:
                chosen = op.dtypes
        if test.unsupported_dtypes_only is True:
            chosen = [d for d in get_all_dtypes() if d not in chosen]
        return chosen if chosen is not None else (None, )

    if hasattr(test, "op_list"):
        for op in test.op_list:
            for dtype in _dtypes_for(op):
                _register(dtype, op)
    else:
        declared = cls._get_dtypes(test)
        for dtype in (tuple(declared) if declared is not None else (None, )):
            _register(dtype, None)