def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False):
    if variant is None:
        self.skipTest("Skipped! Variant not implemented.")
    if not op.supports_dtype(dtype, torch.device(device).type):
        self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

    def is_inplace(variant):
        if hasattr(variant, "__wrapped__"):
            return variant.__wrapped__ is op.get_inplace()
        return variant is op.get_inplace()

    include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
    samples = op.sample_inputs(device, dtype, requires_grad=True,
                               include_conjugated_inputs=include_conjugated_inputs)

    for sample in samples:
        if sample.broadcasts_input and is_inplace(variant):
            continue

        # Note on TensorList inputs
        #
        # gradcheck does not support TensorList inputs so here we pass TensorList
        # inputs of size n as n single Tensor inputs to gradcheck and wrap the op
        # in a function that puts the n Tensor inputs back into a TensorList
        def fn(*inputs):
            # Put tensors back into TensorList since we splat them when passing to gradcheck
            if is_iterable_of_tensors(sample.input):
                n = len(sample.input)
                inputs = (inputs[:n], *inputs[n:])
            output = op.gradcheck_wrapper(variant, *inputs, **sample.kwargs)
            if sample.output_process_fn_grad is not None:
                return sample.output_process_fn_grad(output)
            return output

        # Splat TensorList inputs into single Tensor inputs
        gradcheck_args = (sample.input,) if isinstance(sample.input, torch.Tensor) else tuple(sample.input)
        gradcheck_args += sample.args

        if check == 'gradcheck':
            self.assertTrue(gradcheck(fn, gradcheck_args,
                                      check_batched_grad=op.check_batched_grad,
                                      check_grad_dtypes=True,
                                      nondet_tol=op.gradcheck_nondet_tol,
                                      fast_mode=op.gradcheck_fast_mode,
                                      check_forward_ad=check_forward_ad))
        elif check == 'gradgradcheck':
            self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
            self.assertTrue(gradgradcheck(fn, gradcheck_args,
                                          gen_non_contig_grad_outputs=False,
                                          check_batched_grad=op.check_batched_gradgrad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode))
            self.assertTrue(gradgradcheck(fn, gradcheck_args,
                                          gen_non_contig_grad_outputs=True,
                                          check_batched_grad=op.check_batched_gradgrad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode))
        else:
            self.assertTrue(False, msg="Unknown check requested!")
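# A minimal, self-contained sketch (not part of the test suite) of the TensorList
# splat/recompose trick described in the comment above: gradcheck only accepts
# individual Tensors, so a list of n tensors is passed as n separate positional
# inputs and reassembled inside the wrapper. `stacked_sum` is purely illustrative.
import torch
from torch.autograd import gradcheck

def stacked_sum(tensor_list):
    # Hypothetical op that consumes a TensorList
    return torch.stack(list(tensor_list)).sum(dim=0)

tensors = [torch.randn(3, dtype=torch.double, requires_grad=True) for _ in range(2)]
n = len(tensors)

def fn(*inputs):
    # Recompose the first n positional inputs back into a TensorList
    return stacked_sum(inputs[:n])

assert gradcheck(fn, tuple(tensors))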
def run_test(fast_mode):
    a = wrap(torch.tensor(5.0, dtype=torch.double))
    b = wrap(torch.tensor(6.0, dtype=torch.double))
    a.requires_grad = True
    b.requires_grad = True

    gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
    gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)

    total_used_attrs = a.used_attrs.union(b.used_attrs)
    total_used_calls = a.used_calls.union(b.used_calls)

    # These attributes (and the functions below) may change
    # if the gradcheck implementation changes. It's best to
    # aim for attributes that may be commonly present on other
    # Tensor-likes.
    expected_used_attrs = {
        'data',
        'dtype',
        'is_floating_point',
        'is_sparse',
        'is_sparse_csr',
        'layout',
        'new_zeros',
        'numel',
        'requires_grad',
        'requires_grad_',
        'retain_grad',
        'size',
        'stride',
    }
    if fast_mode:
        expected_used_attrs.add('is_complex')
        expected_used_attrs.add('device')
    self.assertEqual(expected_used_attrs, total_used_attrs)

    expected_used_calls = {
        torch.Tensor.new_zeros,
        torch.Tensor.size,
        torch.Tensor.is_floating_point,
        torch.Tensor.numel,
        torch.Tensor.retain_grad,
        torch.Tensor.stride,
        torch.Tensor.requires_grad_,
        torch.autograd.grad,
        torch.add,
    }
    if fast_mode:
        expected_used_calls.add(torch.Tensor.is_complex)
    self.assertEqual(expected_used_calls, total_used_calls)
def test_gradcheck(self):
    from torch.testing._internal.common_utils import gradcheck, gradgradcheck

    a = wrap(torch.tensor(5.0, dtype=torch.double))
    b = wrap(torch.tensor(6.0, dtype=torch.double))
    a.requires_grad = True
    b.requires_grad = True

    gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False)
    gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False)

    total_used_attrs = a.used_attrs.union(b.used_attrs)
    total_used_calls = a.used_calls.union(b.used_calls)

    # These attributes (and the functions below) may change
    # if the gradcheck implementation changes. It's best to
    # aim for attributes that may be commonly present on other
    # Tensor-likes.
    self.assertEqual(
        total_used_attrs,
        {
            'data',
            'dtype',
            'is_complex',
            'is_floating_point',
            'is_sparse',
            'layout',
            'nelement',
            'new_zeros',
            'requires_grad',
            'retain_grad',
            'size',
            'stride',
        })

    self.assertEqual(
        total_used_calls,
        {
            torch.Tensor.new_zeros,
            torch.Tensor.size,
            torch.Tensor.is_complex,
            torch.Tensor.is_floating_point,
            torch.Tensor.nelement,
            torch.Tensor.retain_grad,
            torch.Tensor.stride,
            torch.autograd.grad,
            torch.add,
        })
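# The `wrap` helper and the `used_attrs` / `used_calls` bookkeeping used in the
# two tests above are defined elsewhere in the test file. A rough, illustrative
# sketch of the idea, assuming a __torch_function__-based Tensor-like that records
# which torch functions gradcheck touches (names here are illustrative, not the
# actual implementation):
import torch

class TrackingTensor(torch.Tensor):
    # Shared across instances; good enough for illustration
    used_calls: set = set()

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Record every torch function invoked on this Tensor-like, then defer
        # to the default Tensor behavior
        cls.used_calls.add(func)
        return super().__torch_function__(func, types, args, kwargs or {})

t = torch.tensor(5.0, dtype=torch.double).as_subclass(TrackingTensor)
torch.add(t, t)
print(torch.add in TrackingTensor.used_calls)  # True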
def test_autograd_to_mkldnn(self):
    # MKLDNN only supports float32
    root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

    def func(root):
        return root.to_mkldnn().to_dense()

    # because MKLDNN only supports float32, we need to lessen the precision.
    # these numbers are just empirical results that seem to work.
    self.assertWarnsRegex(UserWarning, 'double precision floating point',
                          lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
    self.assertWarnsRegex(UserWarning, 'double precision floating point',
                          lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2))
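# For context: gradcheck compares analytical gradients against finite-difference
# estimates, which are numerically fragile in float32, so it warns about double
# precision when an input is not float64 (the warning matched by the regex above).
# A minimal, standalone illustration of capturing that warning; raise_exception=False
# keeps the sketch from erroring out even if the float32 check itself fails:
import warnings
import torch
from torch.autograd import gradcheck

x32 = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    gradcheck(torch.sin, (x32,), atol=4e-2, rtol=1e-2, raise_exception=False)
print(any("double precision" in str(w.message) for w in caught))  # expected: True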
def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False,
                  check_backward_ad=True, check_batched_grad=None, check_batched_forward_grad=False):
    assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
    # NB: check_backward_ad does not affect gradgradcheck (always True)
    if variant is None:
        self.skipTest("Skipped! Variant not implemented.")
    if not op.supports_dtype(dtype, torch.device(device).type):
        self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

    def is_inplace(variant):
        if hasattr(variant, "__wrapped__"):
            return variant.__wrapped__ is op.get_inplace()
        return variant is op.get_inplace()

    include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
    samples = op.sample_inputs(device, dtype, requires_grad=True,
                               include_conjugated_inputs=include_conjugated_inputs,
                               small_inputs_only=is_slow_gradcheck_env())

    for sample in samples:
        if sample.broadcasts_input and is_inplace(variant):
            continue

        # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
        # and tensors passed as kwargs. The following creates a function that accepts just
        # the tensors that require grad as varargs, and then recomposes them back into the
        # original input.

        # Creates gradcheck inputs by identifying tensors requiring grad
        all_args = None
        if is_iterable_of_tensors(sample.input):
            all_args = chain(sample.input, sample.args, sample.kwargs.values())
        else:
            all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
        gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))

        def _input_recomposition_helper(inputs, inp, input_idx):
            if is_iterable_of_tensors(inp):
                tensor_list = []
                for x in inp:
                    if isinstance(x, torch.Tensor) and x.requires_grad:
                        tensor_list.append(inputs[input_idx])
                        input_idx = input_idx + 1
                    else:
                        tensor_list.append(x)
                return tensor_list, input_idx
            elif isinstance(inp, torch.Tensor) and inp.requires_grad:
                return inputs[input_idx], input_idx + 1
            else:
                return inp, input_idx

        def fn(*inputs):
            # Puts inputs back into sample properly
            positional_args = []
            input_idx = 0
            inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
            positional_args.append(inp)

            for x in sample.args:
                inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
                positional_args.append(inp)

            # Recreates kwargs
            kwargs = {}
            for k, v in sample.kwargs.items():
                inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
                kwargs[k] = inp

            output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
            if sample.output_process_fn_grad is not None:
                return sample.output_process_fn_grad(output)
            return output

        if check == 'gradcheck':
            if check_batched_grad is None:
                check_batched_grad = op.check_batched_grad
            self.assertTrue(gradcheck(fn, gradcheck_args,
                                      check_batched_grad=check_batched_grad,
                                      check_grad_dtypes=True,
                                      nondet_tol=op.gradcheck_nondet_tol,
                                      fast_mode=op.gradcheck_fast_mode,
                                      check_forward_ad=check_forward_ad,
                                      check_backward_ad=check_backward_ad,
                                      check_undefined_grad=True,
                                      check_batched_forward_grad=check_batched_forward_grad))
        elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
            self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
            for gen_non_contig_grad_outputs in (False, True):
                kwargs = {
                    "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
                    "check_batched_grad": op.check_batched_gradgrad,
                    "check_grad_dtypes": True,
                    "nondet_tol": op.gradcheck_nondet_tol,
                    "fast_mode": op.gradcheck_fast_mode,
                }
                if check == "fwgrad_bwgrad":
                    kwargs["check_fwd_over_rev"] = True
                    kwargs["check_rev_over_rev"] = False
                    kwargs["check_batched_grad"] = False
                    kwargs["check_undefined_grad"] = False
                self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
        else:
            self.assertTrue(False, msg="Unknown check requested!")