def allSum(vs):
    """Reduce a tensor or a sequence of tensors to one scalar.

    Each tensor's sum is weighted by its 1-based position in the sequence.
    Entries that are ``None`` or whose dtype is not floating point / complex
    (including half and bfloat16) are skipped.  A bare tensor is treated as a
    one-element sequence.  Returns ``0`` when nothing qualifies.
    """
    if isinstance(vs, torch.Tensor):
        vs = (vs, )
    grad_dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
    total = 0
    for position, tensor in enumerate(vs, start=1):
        if tensor is None or tensor.dtype not in grad_dtypes:
            continue
        total = total + position * tensor.sum()
    return total
def test_floating_inputs_are_differentiable(self, device, dtype, op):
    """Checks that every floating-point/complex tensor produced by
    ``op.sample_inputs(..., requires_grad=True)`` actually has
    ``requires_grad`` set (inputs, positional args, and kwargs)."""
    # Non-differentiable ops have nothing to verify.
    if not op.supports_autograd:
        return

    grad_dtypes = set(
        floating_and_complex_types_and(torch.bfloat16, torch.float16))

    def assert_requires_grad(t):
        # Only tensors of a differentiable dtype are checked; everything
        # else (non-tensors, integral/bool tensors) is ignored.
        if not isinstance(t, torch.Tensor) or t.dtype not in grad_dtypes:
            return
        msg = (
            f"Found a sampled tensor of floating-point dtype {t.dtype} sampled with "
            "requires_grad=False. If this is intended, please skip/xfail this test. "
            "Remember that sampling operations are executed under a torch.no_grad contextmanager."
        )
        self.assertTrue(t.requires_grad, msg)

    for sample in op.sample_inputs(device, dtype, requires_grad=True):
        assert_requires_grad(sample.input)
        for positional in sample.args:
            assert_requires_grad(positional)
        for keyword_value in sample.kwargs.values():
            assert_requires_grad(keyword_value)
def test_dtypes(self, device, op):
    """Verifies that the forward and backward dtypes an OpInfo claims to
    support on this device match what actually works when sampling inputs
    and running the op."""
    # dtypes to try to backward in
    allowed_backward_dtypes = floating_and_complex_types_and(
        torch.bfloat16, torch.float16)

    # Sets (not lists) for (un)supported dtypes: O(1) membership tests and
    # idempotent marking, consistent with the set-based bookkeeping used by
    # the newer variant of this test elsewhere in the file.
    supported_dtypes = set()
    unsupported_dtypes = set()
    supported_backward_dtypes = set()
    unsupported_backward_dtypes = set()

    def unsupported(dtype):
        # Marks a dtype as unsupported in forward; if backward would have
        # been attempted for it, marks backward unsupported as well.
        unsupported_dtypes.add(dtype)
        if dtype in allowed_backward_dtypes:
            unsupported_backward_dtypes.add(dtype)

    for dtype in all_types_and_complex_and(torch.half, torch.bfloat16,
                                           torch.bool):
        # tries to acquire samples - failure indicates lack of support
        requires_grad = (dtype in allowed_backward_dtypes
                         and op.supports_autograd)
        try:
            samples = list(
                op.sample_inputs(device, dtype, requires_grad=requires_grad))
        except Exception:
            unsupported(dtype)
            continue

        # Counts number of successful backward attempts
        # NOTE: This exists as a kludge because this only understands how to
        #   request a gradient if the output is a tensor or a sequence with
        #   a tensor as its first element.
        num_backward_successes = 0
        for sample in samples:
            # tries to call operator with the sample - failure indicates
            #   lack of support
            try:
                result = op(sample.input, *sample.args, **sample.kwargs)
            except Exception:
                # NOTE: some ops will fail in forward if their inputs
                #   require grad but they don't support computing the
                #   gradient in that type! This is a bug in the op!
                unsupported(dtype)
                # Short-circuits testing this dtype -- it doesn't work
                break

            # Short-circuits if the dtype isn't a backward dtype or
            #   it's already identified as not supported
            if dtype not in allowed_backward_dtypes or dtype in unsupported_backward_dtypes:
                continue

            # Checks for backward support in the same dtype
            try:
                result = sample.output_process_fn_grad(result)
                if isinstance(result, torch.Tensor):
                    backward_tensor = result
                elif isinstance(result, Sequence) and isinstance(
                        result[0], torch.Tensor):
                    backward_tensor = result[0]
                else:
                    continue

                # Note: this grad may not have the same dtype as dtype
                #   For functions like complex (float -> complex) or abs
                #   (complex -> float) the grad tensor will have a
                #   different dtype than the input.
                #   For simplicity, this is still modeled as these ops
                #   supporting grad in the input dtype.
                grad = torch.randn_like(backward_tensor)
                backward_tensor.backward(grad)
                num_backward_successes += 1
            except Exception:
                unsupported_backward_dtypes.add(dtype)

        if dtype not in unsupported_dtypes:
            supported_dtypes.add(dtype)
        if num_backward_successes > 0 and dtype not in unsupported_backward_dtypes:
            supported_backward_dtypes.add(dtype)

    # Checks that dtypes are listed correctly and generates an informative
    #   error message
    device_type = torch.device(device).type
    claimed_supported = set(op.supported_dtypes(device_type))
    supported_but_unclaimed = supported_dtypes - claimed_supported
    claimed_but_unsupported = claimed_supported - supported_dtypes
    msg = """The supported dtypes for {0} on {1} according to its OpInfo are
    {2}, but the detected supported dtypes are {3}.
    """.format(op.name, device_type, claimed_supported, supported_dtypes)

    if len(supported_but_unclaimed) > 0:
        msg += "The following dtypes should be added to the OpInfo: {0}. ".format(
            supported_but_unclaimed)
    if len(claimed_but_unsupported) > 0:
        msg += "The following dtypes should be removed from the OpInfo: {0}.".format(
            claimed_but_unsupported)

    self.assertEqual(supported_dtypes, claimed_supported, msg=msg)

    # Checks that backward dtypes are listed correctly and generates an
    #   informative error message
    # NOTE: this code is nearly identical to the check + msg generation
    claimed_backward_supported = set(
        op.supported_backward_dtypes(device_type))
    supported_but_unclaimed = supported_backward_dtypes - claimed_backward_supported
    claimed_but_unsupported = claimed_backward_supported - supported_backward_dtypes
    msg = """The supported backward dtypes for {0} on {1} according to its OpInfo are
    {2}, but the detected supported backward dtypes are {3}.
    """.format(op.name, device_type, claimed_backward_supported,
               supported_backward_dtypes)

    if len(supported_but_unclaimed) > 0:
        msg += "The following backward dtypes should be added to the OpInfo: {0}. ".format(
            supported_but_unclaimed)
    if len(claimed_but_unsupported) > 0:
        msg += "The following backward dtypes should be removed from the OpInfo: {0}.".format(
            claimed_but_unsupported)

    self.assertEqual(supported_backward_dtypes,
                     claimed_backward_supported,
                     msg=msg)
def test_out(self, device, op):
    """Exercises the out= kwarg contract for the op: preexisting out values
    must not affect the result, strides/storage of a correctly-shaped out are
    respected, empty outs are resized without warning, and a wrong-device or
    unsafe-cast out raises a RuntimeError."""
    # TODO: verify the op doesn't support the out= kwarg
    if not op.supports_out:
        self.skipTest("Skipped! Op doesn't support out= kwarg.")

    # Prefers running in float32 but has a fallback for the first listed supported dtype
    supported_dtypes = op.supported_dtypes(self.device_type)
    if not supported_dtypes:
        self.skipTest(
            "Skipped! Op has not supported dtypes on this device.")
    if torch.float32 in supported_dtypes:
        dtype = torch.float32
    else:
        dtype = list(supported_dtypes)[0]

    # NOTE: only tests on first sample
    samples = op.sample_inputs(device, dtype)
    sample = first_sample(self, samples)

    # calls it normally to get the expected result
    expected = op(sample.input, *sample.args, **sample.kwargs)
    op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

    # Short-circuits if output is not a single tensor or an
    #   iterable of tensors
    is_tensor = isinstance(expected, torch.Tensor)
    if not is_tensor and not is_iterable_of_tensors(expected,
                                                    include_empty=True):
        self.skipTest(
            "Skipped! Only supports single tensor or iterable of tensor outputs."
        )

    # Applies fn to a single tensor, or maps it over an iterable of tensors
    # (always materializing the result as a tuple).
    def _transform(fn, out):
        if isinstance(out, torch.Tensor):
            return fn(out)
        # assumes (see above) that out is an iterable of tensors
        return tuple(fn(t) for t in out)

    # Collects strides from a tensor or iterable of tensors into a tuple
    def _strides_of(out):
        if isinstance(out, torch.Tensor):
            return (out.stride(), )
        # assumes (see above) that out is an iterable of tensors
        return tuple(t.stride() for t in out)

    # Collects data pointers into a tuple
    # NOTE: only on the CPU and CUDA device types since some
    #   device types don't have storage
    def _ptrs_of(out):
        if self.device_type not in ('cpu', 'cuda'):
            return ()
        if isinstance(out, torch.Tensor):
            return (out.data_ptr(), )
        # assumes (see above) that out is an iterable of tensors
        return tuple(t.data_ptr() for t in out)

    def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
        out = _transform(transform, expected)
        strides_before = _strides_of(out)
        ptrs_before = _ptrs_of(out)
        op_out(out=out)
        self.assertEqual(expected, out)
        if compare_strides_and_data_ptrs:
            self.assertEqual(strides_before, _strides_of(out))
            self.assertEqual(ptrs_before, _ptrs_of(out))

    # Case 0: out= with the correct shape, dtype, and device
    #   but NaN values for floating point and complex tensors, and
    #   maximum values for integer tensors.
    #   Expected behavior: out= values have no effect on the computation.
    def _case_zero_transform(t):
        try:
            return torch.full_like(t, torch.iinfo(t.dtype).max)
        except TypeError:
            # for non-integer types fills with NaN
            return torch.full_like(t, float('nan'))

    _compare_out(_case_zero_transform)

    # Case 1: out= with the correct shape, dtype, and device,
    #   but noncontiguous.
    #   Expected behavior: strides are respected and `out` storage is not changed.
    def _case_one_transform(t):
        return make_tensor(t.shape,
                           dtype=t.dtype,
                           device=t.device,
                           noncontiguous=True)

    _compare_out(_case_one_transform)

    # Case 2: out= with the correct dtype and device, but has no elements.
    #   Expected behavior: resize without warning.
    def _case_two_transform(t):
        return make_tensor((0, ), dtype=t.dtype, device=t.device)

    _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False)

    # Also validates that no warning is thrown when this out is resized
    out = _transform(_case_two_transform, expected)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        op_out(out=out)

    # Verifies no warning is a resize warning
    for w in caught:
        if "An output with one or more elements" in str(w.message):
            self.fail(
                "Resizing an out= argument with no elements threw a resize warning!"
            )

    # Case 3: out= with correct shape and dtype, but wrong device.
    wrong_device = None
    if torch.device(device).type != 'cpu':
        wrong_device = 'cpu'
    elif torch.cuda.is_available():
        wrong_device = 'cuda'

    if wrong_device is not None:

        def _case_three_transform(t):
            return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

        out = _transform(_case_three_transform, expected)
        msg_fail = f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}"
        with self.assertRaises(RuntimeError, msg=msg_fail):
            op_out(out=out)

    # Case 4: out= with correct shape and device, but a dtype
    #   that output cannot be "safely" cast to (long).
    #   Expected behavior: error.
    # NOTE: this case is filtered by dtype since some ops produce
    #   bool tensors, for example, which can be safely cast to any
    #   dtype. It is applied when single tensors are floating point or complex
    #   dtypes, or if an op returns multiple tensors when at least one such
    #   tensor is a floating point or complex dtype.
    _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
    needs_case_four = (
        (is_tensor and expected.dtype in _dtypes)
        or (not is_tensor and any(t.dtype in _dtypes for t in expected)))
    if needs_case_four:

        def _case_four_transform(t):
            return make_tensor(t.shape, dtype=torch.long, device=t.device)

        out = _transform(_case_four_transform, expected)
        if is_tensor:
            msg_fail = (
                "Expected RuntimeError when doing an unsafe cast from a result of dtype "
                f"{expected.dtype} into an out= with dtype torch.long")
        else:
            msg_fail = ""
        with self.assertRaises(RuntimeError, msg=msg_fail):
            op_out(out=out)
unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive", ), DecorateInfo( unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" ), ), gradcheck_wrapper=gradcheck_wrapper_masked_operation, supports_forward_ad=True, supports_out=False, ), OpInfo( "_masked.normalize", method_variant=None, dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_masked_normalize, skips=( DecorateInfo( unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive", ), DecorateInfo( unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" ), # RuntimeError: "clamp_min_cpu" not implemented for 'Half' DecorateInfo( unittest.expectedFailure, "TestMasked", "test_reference_masked",
def test_dtypes(self, device, op):
    """Runs the op's samples across all dtypes (forward and, where allowed,
    backward) and compares what actually works against the dtypes the OpInfo
    claims.  Partial support (some samples work, some don't) only prints a
    warning; mismatches with the OpInfo's claims fail the test."""
    # Check complex32 support only if the op claims.
    # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally.
    device_type = torch.device(device).type
    if op.supports_dtype(torch.complex32, device_type):
        include_complex32 = (torch.complex32, )
    else:
        include_complex32 = ()

    # dtypes to try to backward in
    allowed_backward_dtypes = floating_and_complex_types_and(
        *((torch.half, torch.bfloat16) + include_complex32))

    # lists for (un)supported dtypes
    supported_dtypes = set()
    unsupported_dtypes = set()
    supported_backward_dtypes = set()
    unsupported_backward_dtypes = set()

    def unsupported(dtype):
        # Records a forward failure; backward is marked too when it would
        # have been attempted for this dtype.
        unsupported_dtypes.add(dtype)
        if dtype in allowed_backward_dtypes:
            unsupported_backward_dtypes.add(dtype)

    for dtype in all_types_and_complex_and(
            *((torch.half, torch.bfloat16, torch.bool) + include_complex32)):
        # tries to acquire samples - failure indicates lack of support
        requires_grad = dtype in allowed_backward_dtypes
        try:
            samples = tuple(
                op.sample_inputs(device, dtype, requires_grad=requires_grad))
        except Exception:
            unsupported(dtype)
            continue

        for sample in samples:
            # tries to call operator with the sample - failure indicates
            #   lack of support
            try:
                result = op(sample.input, *sample.args, **sample.kwargs)
                supported_dtypes.add(dtype)
            except Exception:
                # NOTE: some ops will fail in forward if their inputs
                #   require grad but they don't support computing the
                #   gradient in that type! This is a bug in the op!
                unsupported(dtype)
                continue

            # Checks for backward support in the same dtype
            try:
                result = sample.output_process_fn_grad(result)
                if isinstance(result, torch.Tensor):
                    backward_tensor = result
                elif isinstance(result, Sequence) and isinstance(
                        result[0], torch.Tensor):
                    backward_tensor = result[0]
                else:
                    continue

                # Note: this grad may not have the same dtype as dtype
                #   For functions like complex (float -> complex) or abs
                #   (complex -> float) the grad tensor will have a
                #   different dtype than the input.
                #   For simplicity, this is still modeled as these ops
                #   supporting grad in the input dtype.
                grad = torch.randn_like(backward_tensor)
                backward_tensor.backward(grad)
                supported_backward_dtypes.add(dtype)
            except Exception:
                unsupported_backward_dtypes.add(dtype)

    # Checks that dtypes are listed correctly and generates an informative
    #   error message
    # A dtype is "partially supported" when some samples worked and some
    # failed; that ends up in both the supported and unsupported sets.
    supported_forward = supported_dtypes - unsupported_dtypes
    partially_supported_forward = supported_dtypes & unsupported_dtypes
    unsupported_forward = unsupported_dtypes - supported_dtypes
    supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
    partially_supported_backward = supported_backward_dtypes & unsupported_backward_dtypes
    unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes

    claimed_forward = set(op.supported_dtypes(device_type))
    supported_but_unclaimed_forward = supported_forward - claimed_forward
    claimed_but_unsupported_forward = claimed_forward & unsupported_forward

    claimed_backward = set(op.supported_backward_dtypes(device_type))
    supported_but_unclaimed_backward = supported_backward - claimed_backward
    claimed_but_unsupported_backward = claimed_backward & unsupported_backward

    # Partially supporting a dtype is not an error, but we print a warning
    if partially_supported_forward or partially_supported_backward:
        msg = f"Some dtypes for {op.name} on device type {device_type} are only partially supported!\n"
        if partially_supported_forward:
            msg += f"The following dtypes only worked on some samples during forward: {partially_supported_forward}.\n"
        if partially_supported_backward:
            msg += f"The following dtypes only worked on some samples during backward: {partially_supported_backward}.\n"
        print(msg)

    all_mismatches = (supported_but_unclaimed_forward
                      | claimed_but_unsupported_forward
                      | supported_but_unclaimed_backward
                      | claimed_but_unsupported_backward)
    if not all_mismatches:
        return

    # Generates error msg
    msg = f"The supported dtypes for {op.name} on device type {device_type} are incorrect!\n"
    if supported_but_unclaimed_forward:
        msg += f"The following dtypes worked in forward but are not listed by the OpInfo: {supported_but_unclaimed_forward}.\n"
    if supported_but_unclaimed_backward:
        msg += f"The following dtypes worked in backward but are not listed by the OpInfo: {supported_but_unclaimed_backward}.\n"
    if claimed_but_unsupported_forward:
        msg += f"The following dtypes did not work in forward but are listed by the OpInfo: {claimed_but_unsupported_forward}.\n"
    if claimed_but_unsupported_backward:
        msg += f"The following dtypes did not work in backward but are listed by the OpInfo: {claimed_but_unsupported_backward}.\n"

    self.fail(msg)