    def test_out(self, device, dtype, op):
        # TODO: verify the op doesn't support the out= kwarg
        if not op.supports_out:
            self.skipTest("Skipped! Op doesn't support out= kwarg.")

        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype)
        sample = samples[0]

        # calls it normally to get the expected result
        expected = op(sample.input, *sample.args, **sample.kwargs)
        op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

        # Short-circuits if output is not a single tensor or an
        #   iterable of tensors

        # Returns True if iterable is an iterable of tensors (includes empty iterables)
        #   and False o.w.
        def _is_iterable_of_tensors(iterable):
            try:
                for t in iter(iterable):
                    if not isinstance(t, torch.Tensor):
                        return False
            except TypeError:
                return False
            return True

        if not isinstance(expected, torch.Tensor) and not _is_iterable_of_tensors(expected):
            self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")

        # A wrapper around map that works with single tensors and always
        #   instantiates the map. Used below to apply transforms to
        #   single tensor and iterable tensor outputs.
        def _apply_out_transform(fn, out):
            if isinstance(out, torch.Tensor):
                return fn(out)

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(fn, out))

        # Case 0: out= with the correct shape, dtype, and device,
        #   but NaN values for floating point and complex tensors, and
        #   maximum values for integer tensors.
        # Expected behavior: out= values have no effect on the computation.
        def _case_zero_transform(t):
            try:
                info = torch.iinfo(t.dtype)
                return torch.full_like(t, info.max)
            except TypeError:
                # for non-integer types fills with NaN
                return torch.full_like(t, float('nan'))

        out = _apply_out_transform(_case_zero_transform, expected)
        result = op_out(out=out)
        self.assertEqual(expected, out)

        # Checks that the returned value shares storage with out
        # NOTE: only checks on the CPU and CUDA device types since some
        #   device types don't have storage
        if self.device_type == 'cpu' or self.device_type == 'cuda':
            if isinstance(out, torch.Tensor):
                self.assertEqual(out.storage().data_ptr(), result.storage().data_ptr())
            else:
                for out_t, result_t in zip(out, result):
                    self.assertEqual(out_t.storage().data_ptr(), result_t.storage().data_ptr())

        # Case 1: out= with the correct shape, dtype, and device,
        #   but noncontiguous.
        # Expected behavior: strides are respected.
        def _case_one_transform(t):
            return make_tensor(t.shape, dtype=t.dtype, device=t.device, discontiguous=True)

        # Extracts strides from a tensor or iterable of tensors into a tuple
        def _extract_strides(out):
            if isinstance(out, torch.Tensor):
                return (out.stride(),)

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(lambda t: t.stride(), out))

        out = _apply_out_transform(_case_one_transform, expected)
        original_strides = _extract_strides(out)
        op_out(out=out)
        final_strides = _extract_strides(out)

        self.assertEqual(expected, out)
        self.assertEqual(original_strides, final_strides)

        # Case 2: out= with the correct dtype and device, but the wrong shape.
        # Expected behavior: resize with a warning.
        def _case_two_transform(t):
            wrong_shape = list(t.shape)

            if len(wrong_shape) == 0:
                # Handles scalar tensor case (empty list)
                wrong_shape = [2]
            else:
                wrong_shape[-1] = wrong_shape[-1] + 1
            return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

        out = _apply_out_transform(_case_two_transform, expected)
        with self.assertWarnsRegex(UserWarning, "An output with one or more elements"):
            op_out(out=out)
        self.assertEqual(expected, out)

        # Case 3: out= with the correct dtype and device, but an empty
        #   tensor.
        # Expected behavior: resize without warning.
        def _case_three_transform(t):
            return make_tensor((0,), dtype=t.dtype, device=t.device)

        out = _apply_out_transform(_case_three_transform, expected)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            op_out(out=out)

        # Verifies no warning is a resize warning
        for w in caught:
            if "An output with one or more elements" in str(w.message):
                self.fail("Resizing an out= argument with no elements threw a resize warning!")
        self.assertEqual(expected, out)

        # Case 4: out= with correct shape and dtype, but wrong device.
        # Expected behavior: error.
        wrong_device = None
        if torch.device(device).type != 'cpu':
            wrong_device = 'cpu'
        elif torch.cuda.is_available():
            wrong_device = 'cuda'

        if wrong_device is not None:
            def _case_four_transform(t):
                return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

            out = _apply_out_transform(_case_four_transform, expected)
            with self.assertRaises(RuntimeError):
                op_out(out=out)

        # Case 5: out= with correct shape and device, but a dtype
        #   that output cannot be "safely" cast to (long).
        # Expected behavior: error.
        # NOTE: this case is filtered by dtype since some ops produce
        #   bool tensors, for example, which can be safely cast to any
        #   dtype. It is applied when a single tensor output is a floating
        #   point or complex dtype, or when an op returns multiple tensors
        #   and at least one of them is a floating point or complex dtype.
        _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
        if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes or
                (not isinstance(expected, torch.Tensor) and
                 reduce(lambda cur, t: cur or t.dtype in _dtypes, expected, False))):
            def _case_five_transform(t):
                return make_tensor(t.shape, dtype=torch.long, device=t.device)

            out = _apply_out_transform(_case_five_transform, expected)
            with self.assertRaises(RuntimeError):
                op_out(out=out)
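    # Illustrative sketch (not part of the test suite): the out= contract the
    #   cases above exercise, using torch.add as a representative op that
    #   supports out=.
    #
    #   >>> out = torch.full((3,), float('nan'))  # case 0: stale values are ignored
    #   >>> torch.add(torch.ones(3), torch.ones(3), out=out)
    #   tensor([2., 2., 2.])
    #   >>> bad = torch.empty(4)                  # case 2: wrong shape resizes (with a warning)
    #   >>> torch.add(torch.ones(3), torch.ones(3), out=bad).shape
    #   torch.Size([3])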
    def test_dtypes(self, device, dtype, op):
        # dtypes to try to backward in
        allowed_backward_dtypes = floating_and_complex_types_and(torch.bfloat16, torch.float16)

        # lists for (un)supported dtypes
        supported_dtypes = []
        unsupported_dtypes = []
        supported_backward_dtypes = []
        unsupported_backward_dtypes = []

        def unsupported(dtype):
            unsupported_dtypes.append(dtype)
            if dtype in allowed_backward_dtypes:
                unsupported_backward_dtypes.append(dtype)

        for dtype in get_all_dtypes():
            # tries to acquire samples - failure indicates lack of support
            requires_grad = (dtype in allowed_backward_dtypes and op.supports_autograd)
            try:
                samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
            except Exception:
                unsupported(dtype)
                continue

            # Counts number of successful backward attempts
            # NOTE: This exists as a kludge because this only understands how to
            #   request a gradient if the output is a tensor or a sequence with
            #   a tensor as its first element.
            num_backward_successes = 0
            for sample in samples:
                # tries to call operator with the sample - failure indicates
                #   lack of support
                try:
                    result = op(sample.input, *sample.args, **sample.kwargs)
                except Exception:
                    # NOTE: some ops will fail in forward if their inputs
                    #   require grad but they don't support computing the gradient
                    #   in that type! This is a bug in the op!
                    unsupported(dtype)

                # Short-circuits testing this dtype -- it doesn't work
                if dtype in unsupported_dtypes:
                    break

                # Short-circuits if the dtype isn't a backward dtype or
                #   it's already identified as not supported
                if dtype not in allowed_backward_dtypes or dtype in unsupported_backward_dtypes:
                    continue

                # Checks for backward support in the same dtype
                try:
                    result = sample.output_process_fn_grad(result)
                    if isinstance(result, torch.Tensor):
                        backward_tensor = result
                    elif isinstance(result, Sequence) and isinstance(result[0], torch.Tensor):
                        backward_tensor = result[0]
                    else:
                        continue

                    # NOTE: this grad may not have the same dtype as dtype.
                    #   For functions like complex (float -> complex) or abs
                    #   (complex -> float) the grad tensor will have a
                    #   different dtype than the input.
                    #   For simplicity, this is still modeled as these ops
                    #   supporting grad in the input dtype.
                    grad = torch.randn_like(backward_tensor)
                    backward_tensor.backward(grad)
                    num_backward_successes += 1
                except Exception:
                    unsupported_backward_dtypes.append(dtype)

            if dtype not in unsupported_dtypes:
                supported_dtypes.append(dtype)
            if num_backward_successes > 0 and dtype not in unsupported_backward_dtypes:
                supported_backward_dtypes.append(dtype)

        # Checks that dtypes are listed correctly and generates an informative
        #   error message
        device_type = torch.device(device).type
        claimed_supported = set(op.supported_dtypes(device_type))
        supported_dtypes = set(supported_dtypes)

        supported_but_unclaimed = supported_dtypes - claimed_supported
        claimed_but_unsupported = claimed_supported - supported_dtypes
        msg = """The supported dtypes for {0} on {1} according to its OpInfo are
        {2}, but the detected supported dtypes are {3}.
        """.format(op.name, device_type, claimed_supported, supported_dtypes)

        if len(supported_but_unclaimed) > 0:
            msg += "The following dtypes should be added to the OpInfo: {0}. ".format(supported_but_unclaimed)
        if len(claimed_but_unsupported) > 0:
            msg += "The following dtypes should be removed from the OpInfo: {0}.".format(claimed_but_unsupported)

        self.assertEqual(supported_dtypes, claimed_supported, msg=msg)

        # Checks that backward dtypes are listed correctly and generates an
        #   informative error message
        # NOTE: this code is nearly identical to the check + msg generation above
        claimed_backward_supported = set(op.supported_backward_dtypes(device_type))
        supported_backward_dtypes = set(supported_backward_dtypes)

        supported_but_unclaimed = supported_backward_dtypes - claimed_backward_supported
        claimed_but_unsupported = claimed_backward_supported - supported_backward_dtypes
        msg = """The supported backward dtypes for {0} on {1} according to its OpInfo are
        {2}, but the detected supported backward dtypes are {3}.
        """.format(op.name, device_type, claimed_backward_supported, supported_backward_dtypes)

        if len(supported_but_unclaimed) > 0:
            msg += "The following backward dtypes should be added to the OpInfo: {0}. ".format(supported_but_unclaimed)
        if len(claimed_but_unsupported) > 0:
            msg += "The following backward dtypes should be removed from the OpInfo: {0}.".format(claimed_but_unsupported)

        self.assertEqual(supported_backward_dtypes, claimed_backward_supported, msg=msg)
    def test_variant_consistency_jit(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        for sample in samples:
            # Acquires variants to test
            func = op.get_op()
            method = op.get_method()
            inplace = op.get_inplace()
            variants = {
                'function': func,
                'method': method,
                # TODO: inplace tests currently fail
                # 'inplace': inplace,
            }

            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # Create accessor for script function variant
                name = op.name + '_' if func_type == 'inplace' else op.name

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():
                    def fn(*inputs, **kwargs):
                        output = func(*inputs, **kwargs)
                        return op.output_func(output)

                    # bfloat16 grad doesn't work for some operators
                    dtypes_to_grad_check = floating_and_complex_types_and(torch.half) \
                        if op.skip_bfloat16_grad else \
                        floating_and_complex_types_and(torch.half, torch.bfloat16)

                    # Check scripted forward, grad, and grad grad
                    script_fn = create_script_fn(self, name, func_type, op.output_func)
                    check_against_reference(self,
                                            script_fn,
                                            fn,
                                            (*sample.input,) + sample.args,
                                            sample.kwargs,
                                            no_grad=(dtype not in dtypes_to_grad_check))

                    # Check traced forward, grad, and grad grad
                    traced_fn = create_traced_fn(self, variant)
                    check_against_reference(self,
                                            traced_fn,
                                            fn,
                                            (*sample.input,) + sample.args,
                                            sample.kwargs,
                                            no_grad=(dtype not in dtypes_to_grad_check))

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because the schema isn't
                    #   affected by dtype, so running it on all dtypes would be excessive
                    if dtype in [torch.float32, torch.int32]:
                        check_alias_annotation(name,
                                               (*sample.input,) + sample.args,
                                               sample.kwargs,
                                               func_type=func_type,
                                               aten_name=op.aten_name)

                    # Check autodifferentiation of nodes for traced and scripted graphs;
                    #   only need to check once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed,
                                                nonfusible_nodes, fusible_nodes)
                        self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed,
                                                nonfusible_nodes, fusible_nodes)
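    # Illustrative sketch (not part of the test): what "scripted consistency"
    #   means for a single function variant, assuming a unary op like torch.neg.
    #   check_against_reference additionally compares gradients and grad-grads.
    #
    #   def f(x):
    #       return torch.neg(x)
    #   scripted = torch.jit.script(f)
    #   x = torch.randn(3, requires_grad=True)
    #   assert torch.equal(scripted(x), f(x))  # eager and scripted must agree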
class TestOpInfo(TestCase):
    exact_dtype = True

    # Verifies that ops have their unsupported dtypes
    #   registered correctly by testing that each claimed unsupported dtype
    #   throws a runtime error
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(op_db, dtypes=OpDTypes.unsupported)
    def test_unsupported_dtypes(self, device, dtype, op):
        # sample_inputs can have a function for generating the input that doesn't work for specified dtype
        # https://github.com/pytorch/pytorch/issues/49024
        with self.assertRaises(RuntimeError):
            samples = op.sample_inputs(device, dtype)
            if len(samples) == 0:
                self.skipTest("Skipped! No sample inputs!")

            # NOTE: only tests on first sample
            sample = samples[0]
            op(sample.input, *sample.args, **sample.kwargs)

    # Verifies that ops have their supported dtypes
    #   registered correctly by testing that each claimed supported dtype
    #   does NOT throw a runtime error
    # In addition verifies that the generated sample_inputs have the requested device and dtype
    @onlyOnCPUAndCUDA
    @ops(op_db, dtypes=OpDTypes.supported)
    def test_supported_dtypes(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype):
            op(sample.input, *sample.args, **sample.kwargs)
            # NOTE: only checks the first tensor in an iterable of tensors
            sample_input = sample.input[0] if is_iterable_of_tensors(sample.input) else sample.input
            self.assertTrue(sample_input.dtype == dtype)
            self.assertTrue(sample_input.device.type == self.device_type)

    # Verifies that backward for each supported floating or complex dtype
    #   does NOT throw a runtime error.
    # TODO: support multi-tensor outputs
    @onlyOnCPUAndCUDA
    @ops(op_db, allowed_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16))
    def test_supported_backward(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
        if not op.supports_complex_autograd and dtype.is_complex:
            self.skipTest("Skipped! Complex autograd not supported.")

        for sample in op.sample_inputs(device, dtype, requires_grad=True):
            result = op(sample.input, *sample.args, **sample.kwargs)
            if not isinstance(result, torch.Tensor):
                continue

            result.sum().backward()

    # Verifies that ops do not have an entry in
    #   `method_tests` (legacy testing infra).
    @onlyCPU
    @ops(op_db, allowed_dtypes=[torch.float32])
    def test_duplicate_method_tests(self, device, dtype, op):
        self.assertFalse(op.name in method_tested_operators)
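# NOTE: device-generic test classes like TestOpInfo are not run directly.
#   The usual test_ops.py tail (assumed here, since this excerpt ends mid-file)
#   generates per-device variants and launches them:
#
#   instantiate_device_type_tests(TestOpInfo, globals())
#
#   if __name__ == '__main__':
#       run_tests()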