def partial_apply_nontensors(fn, args, **kwargs):
    """Bind the non-tensor positional arguments of ``fn``.

    Every entry of ``args`` that is a ``torch.Tensor`` or an iterable of
    tensors is treated as a "tensor slot"; everything else is frozen in place.

    Returns:
        A pair ``(new_fn, tensor_args)``. ``new_fn(*tensors)`` calls ``fn``
        with the supplied tensors placed in the tensor slots (in order), the
        frozen values in the remaining positions, and ``**kwargs`` forwarded.
        ``tensor_args`` is the list of tensor-slot entries from ``args``.
    """
    def _fills_tensor_slot(value):
        return isinstance(value, torch.Tensor) or is_iterable_of_tensors(value)

    # True where the positional argument has to be supplied by the caller.
    tensor_mask = [_fills_tensor_slot(arg) for arg in args]

    def new_fn(*tensors_):
        pending = iter(tensors_)
        # Lazily interleave caller-supplied tensors with the frozen arguments.
        merged = (
            next(pending) if is_tensor else args[position]
            for position, is_tensor in enumerate(tensor_mask)
        )
        return fn(*merged, **kwargs)

    tensor_args = [arg for arg, is_tensor in zip(args, tensor_mask) if is_tensor]
    return new_fn, tensor_args
def test_supported_dtypes(self, device, dtype, op):
    """Run ``op`` on each sample input and sanity-check the sample itself.

    For every sample produced by ``op.sample_inputs(device, dtype)`` the op is
    invoked once, then the sample's input tensor is checked to have the
    requested ``dtype`` and to live on this test's device type.
    """
    for sample in op.sample_inputs(device, dtype):
        op(sample.input, *sample.args, **sample.kwargs)

        # NOTE: only check the first tensor in the iterable of tensors
        if is_iterable_of_tensors(sample.input):
            first_tensor = sample.input[0]
        else:
            first_tensor = sample.input

        self.assertTrue(first_tensor.dtype == dtype)
        self.assertTrue(first_tensor.device.type == self.device_type)
def fn(*inputs):
    """Call the out= variant with flat gradcheck inputs and post-process it.

    ``inputs`` arrives flattened; when the sample's input is an iterable of
    tensors, the leading ``len(sample.input)`` entries are regrouped into a
    single positional argument before the call.
    """
    call_args = inputs
    # Pack input back into TensorList since we splat it when passing to gradcheck
    if is_iterable_of_tensors(sample.input):
        list_len = len(sample.input)
        call_args = (inputs[:list_len],) + inputs[list_len:]
    return op.output_func(variant_out_fn(*call_args, **sample.kwargs))
def fn(*inputs):
    """Call ``variant`` through ``op.gradcheck_wrapper`` with flat inputs.

    ``inputs`` arrives flattened; when the sample's input is an iterable of
    tensors, the leading ``len(sample.input)`` entries are regrouped into a
    single positional argument. The result is passed through
    ``sample.output_process_fn_grad`` when one is set.
    """
    call_args = inputs
    # Put tensors back into TensorList since we splat them when passing to gradcheck
    if is_iterable_of_tensors(sample.input):
        list_len = len(sample.input)
        call_args = (inputs[:list_len],) + inputs[list_len:]

    result = op.gradcheck_wrapper(variant, *call_args, **sample.kwargs)
    post_process = sample.output_process_fn_grad
    if post_process is None:
        return result
    return post_process(result)
def get_recording_tensors(args):
    """Collect every tensor in ``args`` that has ``requires_grad`` set.

    Top-level tensors are taken directly; for iterables of tensors, each
    member requiring grad is taken. Order of appearance is preserved.
    """
    collected: List[torch.Tensor] = []
    for candidate in args:
        is_grad_tensor = isinstance(candidate, torch.Tensor) and candidate.requires_grad
        if is_grad_tensor:
            collected.append(candidate)
            continue
        if is_iterable_of_tensors(candidate):
            collected.extend(member for member in candidate if member.requires_grad)
    return collected
def clone_inputs(args):
    """Return a list mirroring ``args`` with all tensors detached and cloned.

    Tensors become ``t.detach().clone()``; iterables of tensors become a list
    of such clones; any other value is passed through unchanged.
    """
    def _detached_copy(value):
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        if is_iterable_of_tensors(value):
            return [member.detach().clone() for member in value]
        return value

    cloned: List[Union[torch.Tensor, List[torch.Tensor]]] = [
        _detached_copy(arg) for arg in args
    ]
    return cloned
def clone_inputs(preserve_requires_grad: bool):
    """Clone the enclosing scope's ``args`` via ``clone_tensor``.

    Tensors and members of tensor iterables are cloned with
    ``clone_tensor(t, preserve_requires_grad)``; any other value is passed
    through unchanged. Returns the resulting list.
    """
    def _copy_one(value):
        if isinstance(value, torch.Tensor):
            return clone_tensor(value, preserve_requires_grad)
        if is_iterable_of_tensors(value):
            return [clone_tensor(member, preserve_requires_grad) for member in value]
        return value

    cloned: List[Union[torch.Tensor, List[torch.Tensor]]] = [
        _copy_one(arg) for arg in args
    ]
    return cloned
def _input_recomposition_helper(inputs, inp, input_idx):
    """Rebuild one original argument from the flat sequence ``inputs``.

    Each tensor in ``inp`` that requires grad is replaced by the next unused
    entry of ``inputs`` (starting at ``input_idx``); everything else is kept.

    Returns:
        ``(rebuilt_argument, next_input_idx)``.
    """
    if is_iterable_of_tensors(inp):
        rebuilt = []
        cursor = input_idx
        for element in inp:
            if isinstance(element, torch.Tensor) and element.requires_grad:
                rebuilt.append(inputs[cursor])
                cursor += 1
            else:
                rebuilt.append(element)
        return rebuilt, cursor

    if isinstance(inp, torch.Tensor) and inp.requires_grad:
        return inputs[input_idx], input_idx + 1

    return inp, input_idx
def diff_arg(arg, requires_grad=True):
    """Tell whether ``arg`` counts as a differentiable argument.

    With ``requires_grad=True`` a tensor counts when its ``requires_grad``
    flag is set; otherwise it counts when it is floating point or complex.
    For an iterable of tensors the members must agree.

    Raises:
        RuntimeError: if an iterable of tensors mixes differentiable and
            non-differentiable members.
    """
    def is_differentiable_arg(candidate):
        if requires_grad:
            return candidate.requires_grad
        return candidate.is_floating_point() or candidate.is_complex()

    if not is_iterable_of_tensors(arg):
        return isinstance(arg, Tensor) and is_differentiable_arg(arg)

    verdicts = [is_differentiable_arg(member) for member in arg]
    if all(verdicts):
        return True
    if not any(verdicts):
        return False
    raise RuntimeError("NYI: The test runner can't handle this")
def test_multiple_devices(self, devices, dtype, op):
    """Check that ``op`` produces its output on the device of its inputs.

    For each entry of ``devices`` the first sample input is run through ``op``
    and the resulting tensor (or every tensor of the resulting iterable) must
    be on that device. Other output kinds skip the test.
    """
    for device_name in devices:
        target_device = torch.device(device_name)

        # NOTE: only tests on first sample
        sample = op.sample_inputs(target_device, dtype)[0]
        result = op(sample.input, *sample.args, **sample.kwargs)

        if isinstance(result, torch.Tensor):
            self.assertTrue(result.device == target_device)
        elif is_iterable_of_tensors(result):
            self.assertTrue(all(t.device == target_device for t in result))
        else:
            self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
def get_script_args(args):
    """Split ``args`` into script formals, tensor values and call actuals.

    Tensors and iterables of tensors get a generated parameter name
    (``i0``, ``i1``, ...) which is added to both ``formals`` and ``actuals``;
    tensor-list formals carry a ``List[torch.Tensor]`` annotation. Strings are
    emitted as single-quoted literals and all other values via
    ``str(get_constant(arg))``; these appear in ``actuals`` only.

    Returns:
        The tuple ``(formals, tensors, actuals)``.
    """
    formals: List[str] = []
    tensors: List[Union[torch.Tensor, List[torch.Tensor]]] = []
    actuals: List[str] = []

    def _bind(value, annotation=''):
        # The name index follows the number of formals declared so far.
        placeholder = 'i{}'.format(len(formals))
        formals.append(placeholder + annotation)
        actuals.append(placeholder)
        tensors.append(value)

    for arg in args:
        if isinstance(arg, torch.Tensor):
            _bind(arg)
        elif is_iterable_of_tensors(arg):
            _bind(list(arg), ': List[torch.Tensor]')
        elif isinstance(arg, str):
            actuals.append("'{}'".format(arg))
        else:
            actuals.append(str(get_constant(arg)))

    return (formals, tensors, actuals)
def test_out(self, device, dtype, op):
    """Exercise the ``out=`` keyword of ``op`` on its first sample input.

    The op is first called normally to obtain ``expected``; it is then called
    again with differently prepared ``out=`` arguments (Case 0 through Case 5
    below) and the results, strides, data pointers, warnings and errors are
    checked. Skips when the op does not support ``out=`` or when its output is
    neither a tensor nor an iterable of tensors.
    """
    # TODO: verify the op doesn't support the out= kwarg
    if not op.supports_out:
        self.skipTest("Skipped! Op doesn't support out= kwarg.")

    # NOTE: only tests on first sample
    samples = op.sample_inputs(device, dtype)
    sample = samples[0]

    # calls it normally to get the expected result
    expected = op(sample.input, *sample.args, **sample.kwargs)
    # Same call with everything but out= pre-bound; each case only supplies out=.
    op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

    # Short-circuits if output is not a single tensor or an
    #   iterable of tensors
    if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
            expected, include_empty=True):
        self.skipTest(
            "Skipped! Only supports single tensor or iterable of tensor outputs."
        )

    # A wrapper around map that works with single tensors and always
    #   instantiates the map. Used below to apply transforms to
    #   single tensor and iterable tensor outputs.
    def _apply_out_transform(fn, out):
        if isinstance(out, torch.Tensor):
            return fn(out)

        # assumes (see above) that out is an iterable of tensors
        return tuple(map(fn, out))

    # Case 0: out= with the correct shape, dtype, and device
    #   but NaN values for floating point and complex tensors, and
    #   maximum values for integer tensors.
    #   Expected behavior: out= values have no effect on the computation.
    def _case_zero_transform(t):
        try:
            # torch.iinfo raises TypeError for non-integer dtypes
            info = torch.iinfo(t.dtype)
            return torch.full_like(t, info.max)
        except TypeError as te:
            # for non-integer types fills with NaN
            return torch.full_like(t, float('nan'))

    out = _apply_out_transform(_case_zero_transform, expected)
    result = op_out(out=out)
    self.assertEqual(expected, out)

    # Checks that the returned value shares storage with out
    # NOTE: only checks on the CPU and CUDA device types since some
    #   device types don't have storage
    if self.device_type == 'cpu' or self.device_type == 'cuda':
        if isinstance(out, torch.Tensor):
            self.assertEqual(out.storage().data_ptr(), result.storage().data_ptr())
        else:
            for out_t, result_t in zip(out, result):
                self.assertEqual(out_t.storage().data_ptr(), result_t.storage().data_ptr())

    # Case 1: out= with the correct shape, dtype, and device,
    #   but noncontiguous.
    #   Expected behavior: strides are respected and `out` storage is not changed.
    def _case_one_transform(t):
        return make_tensor(t.shape,
                           dtype=t.dtype,
                           device=t.device,
                           noncontiguous=True)

    # Extracts strides from a tensor or iterable of tensors into a tuple
    def _extract_strides(out):
        if isinstance(out, torch.Tensor):
            return (out.stride(), )

        # assumes (see above) that out is an iterable of tensors
        return tuple(map(lambda t: t.stride(), out))

    # Extracts data pointers from a tensor or iterable of tensors into a tuple
    def _extract_data_ptrs(out):
        if isinstance(out, torch.Tensor):
            return (out.data_ptr(), )

        # assumes (see above) that out is an iterable of tensors
        return tuple(map(lambda t: t.data_ptr(), out))

    out = _apply_out_transform(_case_one_transform, expected)
    # Snapshot strides and pointers before the call so they can be compared after it.
    original_strides = _extract_strides(out)
    original_ptrs = _extract_data_ptrs(out)

    op_out(out=out)
    final_strides = _extract_strides(out)
    final_ptrs = _extract_data_ptrs(out)

    self.assertEqual(expected, out)
    self.assertEqual(original_strides, final_strides)
    self.assertEqual(original_ptrs, final_ptrs)

    # Case 2: out= with the correct dtype and device, but the wrong shape
    #   Expected behavior: resize with a warning.
    def _case_two_transform(t):
        wrong_shape = list(t.shape)

        if len(wrong_shape) == 0:
            # Handles scalar tensor case (empty list)
            wrong_shape = [2]
        else:
            # Grows the last dimension by one so the shape no longer matches
            wrong_shape[-1] = wrong_shape[-1] + 1
        return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

    out = _apply_out_transform(_case_two_transform, expected)
    msg_fail = "Resized a non-empty tensor but did not warn about it."
    with self.assertWarnsRegex(UserWarning, "An output with one or more elements", msg=msg_fail):
        op_out(out=out)
    self.assertEqual(expected, out)

    # Case 3: out= with the correct dtype and device, but an empty
    #   tensor.
    #   Expected behavior: resize without warning.
    def _case_three_transform(t):
        return make_tensor((0, ), dtype=t.dtype, device=t.device)

    out = _apply_out_transform(_case_three_transform, expected)
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones that would normally be suppressed
        warnings.simplefilter("always")
        op_out(out=out)

    # Verifies no warning is a resize warning
    for w in caught:
        if "An output with one or more elements" in str(w.message):
            self.fail(
                "Resizing an out= argument with no elements threw a resize warning!"
            )

    self.assertEqual(expected, out)

    # Case 4: out= with correct shape and dtype, but wrong device.
    #   Only runs when a second device type is available: 'cpu' when testing
    #   on a non-CPU device, 'cuda' when testing on CPU with CUDA available.
    wrong_device = None
    if torch.device(device).type != 'cpu':
        wrong_device = 'cpu'
    elif torch.cuda.is_available():
        wrong_device = 'cuda'

    if wrong_device is not None:

        def _case_four_transform(t):
            return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

        out = _apply_out_transform(_case_four_transform, expected)
        msg_fail = f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}"
        with self.assertRaises(RuntimeError, msg=msg_fail):
            op_out(out=out)

    # Case 5: out= with correct shape and device, but a dtype
    #   that output cannot be "safely" cast to (long).
    #   Expected behavior: error.
    # NOTE: this case is filtered by dtype since some ops produce
    #   bool tensors, for example, which can be safely cast to any
    #   dtype. It is applied when single tensors are floating point or complex
    #   dtypes, or if an op returns multiple tensors when at least one such
    #   tensor is a floating point or complex dtype.
    _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
    if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes
            or (not isinstance(expected, torch.Tensor)
                and any(t.dtype in _dtypes for t in expected))):

        def _case_five_transform(t):
            return make_tensor(t.shape, dtype=torch.long, device=t.device)

        out = _apply_out_transform(_case_five_transform, expected)
        # The failure message is only filled in for single-tensor outputs;
        # for iterables of tensors it is the empty string.
        msg_fail = "" if not isinstance(expected, torch.Tensor) else \
            ("Expected RuntimeError when doing an unsafe cast from a result of dtype "
             f"{expected.dtype} into an out= with dtype torch.long")
        with self.assertRaises(RuntimeError, msg=msg_fail):
            op_out(out=out)
def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
                  check_batched_grad=None, check_batched_forward_grad=False):
    """Run gradcheck or gradgradcheck for ``variant`` of ``op`` on its samples.

    Args:
        device, dtype: forwarded to ``op.sample_inputs``.
        op: the operator description providing samples and gradcheck settings.
        variant: the callable to differentiate; the test is skipped when None.
        check: one of ``'gradcheck'``, ``'bwgrad_bwgrad'``, ``'fwgrad_bwgrad'``.
        check_forward_ad, check_backward_ad, check_batched_forward_grad:
            forwarded to ``gradcheck``; ``check_forward_ad`` must be false for
            the gradgrad checks.
        check_batched_grad: forwarded to ``gradcheck``; defaults to
            ``op.check_batched_grad`` when None.
    """
    assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
    # NB: check_backward_ad does not affect gradgradcheck (always True)
    if variant is None:
        self.skipTest("Skipped! Variant not implemented.")
    if not op.supports_dtype(dtype, torch.device(device).type):
        self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

    # True when `variant` (or the callable it wraps) is the op's inplace variant
    def is_inplace(variant):
        if hasattr(variant, "__wrapped__"):
            return variant.__wrapped__ is op.get_inplace()
        return variant is op.get_inplace()

    include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
    samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
                               small_inputs_only=is_slow_gradcheck_env())

    for sample in samples:
        # Samples that broadcast their input are not run against the inplace variant
        if sample.broadcasts_input and is_inplace(variant):
            continue

        # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
        #   and tensors passed as kwargs. The following creates a function that accepts just
        #   the tensors that require grad as varargs, and then recomposes them back into the
        #   original input.

        # Creates gradcheck inputs by identifying tensors requiring grad
        all_args = None
        if is_iterable_of_tensors(sample.input):
            # The members of the input list are splatted so each is considered on its own.
            # This is a one-shot iterator; it is consumed exactly once just below.
            all_args = chain(sample.input, sample.args, sample.kwargs.values())
        else:
            all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
        # NOTE(review): only top-level tensors are kept here. Tensor lists inside
        #   sample.args / sample.kwargs are not flattened, while
        #   _input_recomposition_helper below consumes one input per nested tensor
        #   requiring grad -- confirm no sample passes such lists.
        gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))

        # Replaces each tensor requiring grad in `inp` by the next unused entry of
        # `inputs` and returns the rebuilt argument with the advanced index.
        def _input_recomposition_helper(inputs, inp, input_idx):
            if is_iterable_of_tensors(inp):
                tensor_list = []
                for x in inp:
                    if isinstance(x, torch.Tensor) and x.requires_grad:
                        tensor_list.append(inputs[input_idx])
                        input_idx = input_idx + 1
                    else:
                        tensor_list.append(x)
                return tensor_list, input_idx
            elif isinstance(inp, torch.Tensor) and inp.requires_grad:
                return inputs[input_idx], input_idx + 1
            else:
                return inp, input_idx

        def fn(*inputs):
            # Puts inputs back into sample properly
            positional_args = []
            input_idx = 0
            inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
            positional_args.append(inp)

            for x in sample.args:
                inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
                positional_args.append(inp)

            # Recreates kwargs
            kwargs = {}
            for k, v in sample.kwargs.items():
                inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
                kwargs[k] = inp

            output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
            if sample.output_process_fn_grad is not None:
                return sample.output_process_fn_grad(output)
            return output

        if check == 'gradcheck':
            # Once resolved from the op, the value is reused for later samples
            if check_batched_grad is None:
                check_batched_grad = op.check_batched_grad
            self.assertTrue(gradcheck(fn, gradcheck_args,
                                      check_batched_grad=check_batched_grad,
                                      check_grad_dtypes=True,
                                      nondet_tol=op.gradcheck_nondet_tol,
                                      fast_mode=op.gradcheck_fast_mode,
                                      check_forward_ad=check_forward_ad,
                                      check_backward_ad=check_backward_ad,
                                      check_undefined_grad=True,
                                      check_batched_forward_grad=check_batched_forward_grad))
        elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
            self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
            # Runs once with contiguous and once with non-contiguous grad outputs
            for gen_non_contig_grad_outputs in (False, True):
                kwargs = {
                    "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
                    "check_batched_grad": op.check_batched_gradgrad,
                    "check_grad_dtypes": True,
                    "nondet_tol": op.gradcheck_nondet_tol,
                    "fast_mode": op.gradcheck_fast_mode
                }
                if check == "fwgrad_bwgrad":
                    # Forward-over-reverse mode: the batched and undefined-grad checks are turned off
                    kwargs["check_fwd_over_rev"] = True
                    kwargs["check_rev_over_rev"] = False
                    kwargs["check_batched_grad"] = False
                    kwargs["check_undefined_grad"] = False

                self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
        else:
            # Not reachable while the assert at the top of the function is active
            self.assertTrue(False, msg="Unknown check requested!")
def test_out_warning(self, device, op):
    """Check that resizing a non-empty, wrongly shaped ``out=`` warns.

    Uses the first sample input of ``op`` (in float32 when supported,
    otherwise the first listed supported dtype). The op is called with an
    ``out=`` of the wrong shape; the result must still equal the normally
    computed output and a ``UserWarning`` about the resize must be raised.
    Skips when the op does not support ``out=``, has no supported dtypes on
    this device, or returns something other than a tensor or an iterable of
    tensors.
    """
    # TODO: verify the op doesn't support the out= kwarg
    if not op.supports_out:
        self.skipTest("Skipped! Op doesn't support out= kwarg.")

    # Prefers running in float32 but has a fallback for the first listed supported dtype
    supported_dtypes = op.supported_dtypes(self.device_type)
    if len(supported_dtypes) == 0:
        self.skipTest(
            "Skipped! Op has not supported dtypes on this device.")
    dtype = torch.float32 if torch.float32 in supported_dtypes else list(
        supported_dtypes)[0]

    # NOTE: only tests on first sample
    samples = op.sample_inputs(device, dtype)
    sample = first_sample(self, samples)

    # calls it normally to get the expected result
    expected = op(sample.input, *sample.args, **sample.kwargs)
    # Same call with everything but out= pre-bound
    op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

    # Short-circuits if output is not a single tensor or an
    #   iterable of tensors
    if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
            expected, include_empty=True):
        self.skipTest(
            "Skipped! Only supports single tensor or iterable of tensor outputs."
        )

    # A wrapper around map that works with single tensors and always
    #   instantiates the map. Used below to apply transforms to
    #   single tensor and iterable tensor outputs.
    def _apply_out_transform(fn, out):
        if isinstance(out, torch.Tensor):
            return fn(out)

        # assumes (see above) that out is an iterable of tensors
        return tuple(map(fn, out))

    # Extracts strides from a tensor or iterable of tensors into a tuple
    def _extract_strides(out):
        if isinstance(out, torch.Tensor):
            return (out.stride(), )

        # assumes (see above) that out is an iterable of tensors
        return tuple(map(lambda t: t.stride(), out))

    # Extracts data pointers from a tensor or iterable of tensors into a tuple
    # NOTE: only extracts on the CPU and CUDA device types since some
    #   device types don't have storage
    def _extract_data_ptrs(out):
        if self.device_type != 'cpu' and self.device_type != 'cuda':
            return ()

        if isinstance(out, torch.Tensor):
            return (out.data_ptr(), )

        # assumes (see above) that out is an iterable of tensors
        return tuple(map(lambda t: t.data_ptr(), out))

    # Builds an out= from `expected` via `transform`, runs the op into it and
    # checks the values; optionally also checks that strides and data pointers
    # are unchanged by the call.
    def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
        out = _apply_out_transform(transform, expected)
        original_strides = _extract_strides(out)
        original_ptrs = _extract_data_ptrs(out)

        op_out(out=out)
        final_strides = _extract_strides(out)
        final_ptrs = _extract_data_ptrs(out)

        self.assertEqual(expected, out)

        if compare_strides_and_data_ptrs:
            self.assertEqual(original_strides, final_strides)
            self.assertEqual(original_ptrs, final_ptrs)

    # Case: out= with the correct dtype and device, but the wrong shape
    #   Expected behavior: resize with a warning.
    def _case_two_transform(t):
        wrong_shape = list(t.shape)

        if len(wrong_shape) == 0:
            # Handles scalar tensor case (empty list)
            wrong_shape = [2]
        else:
            # Grows the last dimension by one so the shape no longer matches
            wrong_shape[-1] = wrong_shape[-1] + 1
        return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

    # Strides and pointers are not compared since the out= is expected to be resized
    _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False)

    # Additionally validates that the appropriate warning is thrown
    out = _apply_out_transform(_case_two_transform, expected)
    msg_fail = "Resized a non-empty tensor but did not warn about it."
    with self.assertWarnsRegex(UserWarning, "An output with one or more elements", msg=msg_fail):
        op_out(out=out)
def _is_tensor_input(arg):
    """Return whether ``arg`` is a tensor or an iterable of tensors."""
    if isinstance(arg, torch.Tensor):
        return True
    return is_iterable_of_tensors(arg)
def test_out(self, device, op):
    """Exercise the ``out=`` keyword of ``op`` on its sample inputs.

    Runs in float32 when supported, otherwise in the first listed supported
    dtype. For each sample the op is called normally to obtain ``expected``
    and then called with differently prepared ``out=`` arguments (Case 0
    through Case 4 below). When the op claims not to support ``out=``, the
    test instead checks that passing ``out=`` raises.
    """
    # Prefers running in float32 but has a fallback for the first listed supported dtype
    supported_dtypes = op.supported_dtypes(self.device_type)
    if len(supported_dtypes) == 0:
        self.skipTest(
            "Skipped! Op has not supported dtypes on this device.")
    dtype = torch.float32 if torch.float32 in supported_dtypes else list(
        supported_dtypes)[0]

    samples = op.sample_inputs(device, dtype)
    for sample in samples:
        # calls it normally to get the expected result
        expected = op(sample.input, *sample.args, **sample.kwargs)
        # Same call with everything but out= pre-bound
        op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

        # Short-circuits if output is not a single tensor or an
        #   iterable of tensors
        if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
                expected, include_empty=True):
            self.skipTest(
                "Skipped! Only supports single tensor or iterable of tensor outputs."
            )

        # Validates the op doesn't support out if it claims not to
        # NOTE(review): the `return` ends the whole test after the first
        #   sample in this branch -- confirm that is intended.
        if not op.supports_out:
            with self.assertRaises(Exception):
                assert op_out(out=expected) != NotImplemented
            return

        # A wrapper around map that works with single tensors and always
        #   instantiates the map. Used below to apply transforms to
        #   single tensor and iterable tensor outputs.
        def _apply_out_transform(fn, out):
            if isinstance(out, torch.Tensor):
                return fn(out)

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(fn, out))

        # Extracts strides from a tensor or iterable of tensors into a tuple
        def _extract_strides(out):
            if isinstance(out, torch.Tensor):
                return (out.stride(), )

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(lambda t: t.stride(), out))

        # Extracts data pointers from a tensor or iterable of tensors into a tuple
        # NOTE: only extracts on the CPU and CUDA device types since some
        #   device types don't have storage
        def _extract_data_ptrs(out):
            if self.device_type != 'cpu' and self.device_type != 'cuda':
                return ()

            if isinstance(out, torch.Tensor):
                return (out.data_ptr(), )

            # assumes (see above) that out is an iterable of tensors
            return tuple(map(lambda t: t.data_ptr(), out))

        # Builds an out= from `expected` via `transform`, runs the op into it
        # and checks the values; optionally also checks that strides and data
        # pointers are unchanged by the call.
        def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
            out = _apply_out_transform(transform, expected)
            original_strides = _extract_strides(out)
            original_ptrs = _extract_data_ptrs(out)

            op_out(out=out)
            final_strides = _extract_strides(out)
            final_ptrs = _extract_data_ptrs(out)

            self.assertEqual(expected, out)

            if compare_strides_and_data_ptrs:
                stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
                    original_strides, final_strides)
                self.assertEqual(original_strides, final_strides, msg=stride_msg)
                self.assertEqual(original_ptrs, final_ptrs)

        # Case 0: out= with the correct shape, dtype, and device
        #   but NaN values for floating point and complex tensors, and
        #   maximum values for integer tensors.
        #   Expected behavior: out= values have no effect on the computation.
        def _case_zero_transform(t):
            try:
                # torch.iinfo raises TypeError for non-integer dtypes
                info = torch.iinfo(t.dtype)
                return torch.full_like(t, info.max)
            except TypeError as te:
                # for non-integer types fills with NaN
                return torch.full_like(t, float('nan'))

        _compare_out(_case_zero_transform)

        # Case 1: out= with the correct shape, dtype, and device,
        #   but noncontiguous.
        #   Expected behavior: strides are respected and `out` storage is not changed.
        def _case_one_transform(t):
            return make_tensor(t.shape,
                               dtype=t.dtype,
                               device=t.device,
                               noncontiguous=True)

        _compare_out(_case_one_transform)

        # Case 2: out= with the correct dtype and device, but has no elements.
        #   Expected behavior: resize without warning.
        def _case_two_transform(t):
            return make_tensor((0, ), dtype=t.dtype, device=t.device)

        # Strides and pointers are not compared since the out= is expected to be resized
        _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False)

        # Also validates that no warning is thrown when this out is resized
        out = _apply_out_transform(_case_two_transform, expected)
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning, including ones that would normally be suppressed
            warnings.simplefilter("always")
            op_out(out=out)

        # Verifies no warning is a resize warning
        for w in caught:
            if "An output with one or more elements" in str(w.message):
                self.fail(
                    "Resizing an out= argument with no elements threw a resize warning!"
                )

        # Case 3: out= with correct shape and dtype, but wrong device.
        #   Only runs when a second device type is available: 'cpu' when
        #   testing on a non-CPU device, 'cuda' when testing on CPU with CUDA
        #   available.
        wrong_device = None
        if torch.device(device).type != 'cpu':
            wrong_device = 'cpu'
        elif torch.cuda.is_available():
            wrong_device = 'cuda'

        if wrong_device is not None:

            def _case_three_transform(t):
                return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

            out = _apply_out_transform(_case_three_transform, expected)
            msg_fail = f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}"
            with self.assertRaises(RuntimeError, msg=msg_fail):
                op_out(out=out)

        # Case 4: out= with correct shape and device, but a dtype
        #   that output cannot be "safely" cast to (long).
        #   Expected behavior: error.
        # NOTE: this case is filtered by dtype since some ops produce
        #   bool tensors, for example, which can be safely cast to any
        #   dtype. It is applied when single tensors are floating point or complex
        #   dtypes, or if an op returns multiple tensors when at least one such
        #   tensor is a floating point or complex dtype.
        _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
        if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes
                or (not isinstance(expected, torch.Tensor)
                    and any(t.dtype in _dtypes for t in expected))):

            def _case_four_transform(t):
                return make_tensor(t.shape, dtype=torch.long, device=t.device)

            out = _apply_out_transform(_case_four_transform, expected)
            # Generic message for iterables of tensors; single tensors get a
            # message naming the result dtype.
            msg_fail = "Expected RuntimeError when doing an unsafe cast!"
            msg_fail = msg_fail if not isinstance(expected, torch.Tensor) else \
                ("Expected RuntimeError when doing an unsafe cast from a result of dtype "
                 f"{expected.dtype} into an out= with dtype torch.long")
            with self.assertRaises(RuntimeError, msg=msg_fail):
                op_out(out=out)