Example #1
    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False):
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def is_inplace(variant):
            if hasattr(variant, "__wrapped__"):
                return variant.__wrapped__ is op.get_inplace()
            return variant is op.get_inplace()

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs)

        for sample in samples:
            if sample.broadcasts_input and is_inplace(variant):
                continue

            # Note on TensorList inputs
            #
            # gradcheck does not support TensorList inputs so here we pass TensorList
            # inputs of size n as n single Tensor inputs to gradcheck and wrap the op
            # in a function that puts the n Tensor inputs back into a TensorList
            def fn(*inputs):
                # Put tensors back into TensorList since we splat them when passing to gradcheck
                if is_iterable_of_tensors(sample.input):
                    n = len(sample.input)
                    inputs = (inputs[:n], *inputs[n:])
                output = op.gradcheck_wrapper(variant, *inputs, **sample.kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            # Splat TensorList inputs into single Tensor inputs
            gradcheck_args = (sample.input,) if isinstance(sample.input, torch.Tensor) else tuple(sample.input)
            gradcheck_args += sample.args

            if check == 'gradcheck':
                self.assertTrue(gradcheck(fn, gradcheck_args,
                                          check_batched_grad=op.check_batched_grad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode,
                                          check_forward_ad=check_forward_ad))
            elif check == 'gradgradcheck':
                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                self.assertTrue(gradgradcheck(fn, gradcheck_args,
                                              gen_non_contig_grad_outputs=False,
                                              check_batched_grad=op.check_batched_gradgrad,
                                              check_grad_dtypes=True,
                                              nondet_tol=op.gradcheck_nondet_tol,
                                              fast_mode=op.gradcheck_fast_mode))
                self.assertTrue(gradgradcheck(fn, gradcheck_args,
                                              gen_non_contig_grad_outputs=True,
                                              check_batched_grad=op.check_batched_gradgrad,
                                              check_grad_dtypes=True,
                                              nondet_tol=op.gradcheck_nondet_tol,
                                              fast_mode=op.gradcheck_fast_mode))
            else:
                self.fail("Unknown check requested!")
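
The TensorList-splatting trick described in the note above also works outside the OpInfo machinery. Below is a minimal, hedged sketch that assumes a toy `cat_all` function (not part of the snippet above) taking a list of tensors; gradcheck itself only accepts plain tensors, so the wrapper re-packs the splatted arguments.

import torch
from torch.autograd import gradcheck

def cat_all(tensors):
    # Toy op that consumes a TensorList (a Python list of tensors).
    return torch.cat(tensors).sum()

inputs = [torch.randn(3, dtype=torch.double, requires_grad=True) for _ in range(2)]

def fn(*flat):
    # Re-pack the splatted tensors into the list the op expects.
    return cat_all(list(flat))

assert gradcheck(fn, tuple(inputs))
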
Example #2
        def run_test(fast_mode):
            a = wrap(torch.tensor(5.0, dtype=torch.double))
            b = wrap(torch.tensor(6.0, dtype=torch.double))

            a.requires_grad = True
            b.requires_grad = True

            gradcheck(torch.add, (a, b),
                      raise_exception=False,
                      check_batched_grad=False,
                      fast_mode=fast_mode)
            gradgradcheck(torch.add, (a, b),
                          raise_exception=False,
                          check_batched_grad=False,
                          fast_mode=fast_mode)

            total_used_attrs = a.used_attrs.union(b.used_attrs)
            total_used_calls = a.used_calls.union(b.used_calls)

            # These attributes (and the functions below) may change
            # if the gradcheck implementation changes. It's best to
            # aim for attributes that may be commonly present on other
            # Tensor-likes.
            expected_used_attrs = {
                'data',
                'dtype',
                'is_floating_point',
                'is_sparse',
                'is_sparse_csr',
                'layout',
                'new_zeros',
                'numel',
                'requires_grad',
                'requires_grad_',
                'retain_grad',
                'size',
                'stride',
            }
            if fast_mode:
                expected_used_attrs.add('is_complex')
                expected_used_attrs.add('device')
            self.assertEqual(expected_used_attrs, total_used_attrs)

            expected_used_calls = {
                torch.Tensor.new_zeros,
                torch.Tensor.size,
                torch.Tensor.is_floating_point,
                torch.Tensor.numel,
                torch.Tensor.retain_grad,
                torch.Tensor.stride,
                torch.Tensor.requires_grad_,
                torch.autograd.grad,
                torch.add,
            }
            if fast_mode:
                expected_used_calls.add(torch.Tensor.is_complex)
            self.assertEqual(expected_used_calls, total_used_calls)
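
The `wrap` helper and the `used_attrs` / `used_calls` sets come from the PyTorch test suite and are not shown in this snippet. As a rough, hedged sketch of the recording idea (calls only, not attribute accesses), a `__torch_function__` subclass along these lines could track which torch functions gradcheck dispatches on its inputs:

import torch

class TrackingTensor(torch.Tensor):
    # Hypothetical stand-in for the test suite's wrapper; records which
    # torch functions were dispatched on instances of this subclass.
    @staticmethod
    def __new__(cls, data):
        t = torch.Tensor._make_subclass(cls, data)
        t.used_calls = set()
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        for a in args:
            if isinstance(a, TrackingTensor):
                a.used_calls.add(func)
        # Fall back to the default behavior for the actual computation.
        return super().__torch_function__(func, types, args, kwargs)

x = TrackingTensor(torch.tensor(2.0, dtype=torch.double))
y = TrackingTensor(torch.tensor(3.0, dtype=torch.double))
torch.add(x, y)
assert torch.add in x.used_calls
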
Example #3
    def test_gradcheck(self):
        from torch.testing._internal.common_utils import gradcheck, gradgradcheck

        a = wrap(torch.tensor(5.0, dtype=torch.double))
        b = wrap(torch.tensor(6.0, dtype=torch.double))

        a.requires_grad = True
        b.requires_grad = True

        gradcheck(torch.add, (a, b),
                  raise_exception=False,
                  check_batched_grad=False)
        gradgradcheck(torch.add, (a, b),
                      raise_exception=False,
                      check_batched_grad=False)

        total_used_attrs = a.used_attrs.union(b.used_attrs)
        total_used_calls = a.used_calls.union(b.used_calls)

        # These attributes (and the functions below) may change
        # if the gradcheck implementation changes. It's best to
        # aim for attributes that may be commonly present on other
        # Tensor-likes.
        self.assertEqual(
            total_used_attrs, {
                'data',
                'dtype',
                'is_complex',
                'is_floating_point',
                'is_sparse',
                'layout',
                'nelement',
                'new_zeros',
                'requires_grad',
                'retain_grad',
                'size',
                'stride',
            })

        self.assertEqual(
            total_used_calls, {
                torch.Tensor.new_zeros,
                torch.Tensor.size,
                torch.Tensor.is_complex,
                torch.Tensor.is_floating_point,
                torch.Tensor.nelement,
                torch.Tensor.retain_grad,
                torch.Tensor.stride,
                torch.autograd.grad,
                torch.add,
            })
Example #4
    def test_autograd_to_mkldnn(self):
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

        def func(root):
            return root.to_mkldnn().to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2))
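
MKL-DNN tensors are float32-only, which is why the test above both loosens atol/rtol and asserts the gradcheck warning about non-double inputs. For ops that do support float64, the usual way to keep gradcheck precise and warning-free is simply to feed double-precision inputs, as in this generic sketch (not tied to the MKL-DNN test):

import torch
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
assert gradcheck(lambda t: t.exp().sum(), (x,))
assert gradgradcheck(lambda t: t.exp().sum(), (x,))
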
Example #5
    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
                      check_batched_grad=None, check_batched_forward_grad=False):
        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
        # NB: check_backward_ad does not affect gradgradcheck (always True)
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def is_inplace(variant):
            if hasattr(variant, "__wrapped__"):
                return variant.__wrapped__ is op.get_inplace()
            return variant is op.get_inplace()

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex

        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
                                   small_inputs_only=is_slow_gradcheck_env())

        for sample in samples:
            if sample.broadcasts_input and is_inplace(variant):
                continue

            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
            #   and tensors passed as kwargs. The following creates a function that accepts just
            #   the tensors that require grad as varargs, and then recomposes them back into the
            #   original input.

            # Creates gradcheck inputs by identifying tensors requiring grad
            all_args = None
            if is_iterable_of_tensors(sample.input):
                all_args = chain(sample.input, sample.args, sample.kwargs.values())
            else:
                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))

            def _input_recomposition_helper(inputs, inp, input_idx):
                if is_iterable_of_tensors(inp):
                    tensor_list = []
                    for x in inp:
                        if isinstance(x, torch.Tensor) and x.requires_grad:
                            tensor_list.append(inputs[input_idx])
                            input_idx = input_idx + 1
                        else:
                            tensor_list.append(x)
                    return tensor_list, input_idx
                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
                    return inputs[input_idx], input_idx + 1
                else:
                    return inp, input_idx

            def fn(*inputs):
                # Puts inputs back into sample properly
                positional_args = []
                input_idx = 0
                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
                positional_args.append(inp)

                for x in sample.args:
                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
                    positional_args.append(inp)

                # Recreates kwargs
                kwargs = {}
                for k, v in sample.kwargs.items():
                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
                    kwargs[k] = inp

                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            if check == 'gradcheck':
                if check_batched_grad is None:
                    check_batched_grad = op.check_batched_grad
                self.assertTrue(gradcheck(fn, gradcheck_args,
                                          check_batched_grad=check_batched_grad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode,
                                          check_forward_ad=check_forward_ad,
                                          check_backward_ad=check_backward_ad,
                                          check_undefined_grad=True,
                                          check_batched_forward_grad=check_batched_forward_grad))
            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                for gen_non_contig_grad_outputs in (False, True):
                    kwargs = {
                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
                        "check_batched_grad": op.check_batched_gradgrad,
                        "check_grad_dtypes": True,
                        "nondet_tol": op.gradcheck_nondet_tol,
                        "fast_mode": op.gradcheck_fast_mode
                    }
                    if check == "fwgrad_bwgrad":
                        kwargs["check_fwd_over_rev"] = True
                        kwargs["check_rev_over_rev"] = False
                        kwargs["check_batched_grad"] = False
                        kwargs["check_undefined_grad"] = False

                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
            else:
                self.fail("Unknown check requested!")
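
`torch.autograd.gradcheck` exposes `check_forward_ad` directly (the helper above forwards it), so forward-mode AD coverage can also be exercised outside the OpInfo machinery. A minimal sketch for an op with a forward-mode formula:

import torch
from torch.autograd import gradcheck

x = torch.randn(3, dtype=torch.double, requires_grad=True)
assert gradcheck(torch.sin, (x,), check_forward_ad=True)
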