Example No. 1
    def test_autograd_to_mkldnn(self):
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

        def func(root):
            return root.to_mkldnn().to_dense()

        # Because MKLDNN only supports float32, we need to loosen the tolerances.
        # These values are empirical and seem to work.
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
        self.assertWarnsRegex(UserWarning,
                              'double precision floating point',
                              lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2))
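For context, a minimal standalone sketch (plain CPU tensors, no MKLDNN; not part of the test above) of why the test loosens tolerances and asserts a UserWarning: gradcheck estimates a numerical Jacobian, so it prefers float64 inputs and warns about 'double precision floating point' when handed float32.

import torch
from torch.autograd import gradcheck

# A float64 input passes with the default (tight) tolerances and raises no warning;
# a float32 input would trigger the warning asserted above and need looser atol/rtol.
x = torch.randn(4, 5, dtype=torch.float64, requires_grad=True)
assert gradcheck(lambda t: t * 2, (x,))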
Example No. 2
    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
                      check_batched_grad=None, check_batched_forward_grad=False):
        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
        # NB: check_backward_ad does not affect gradgradcheck (always True)
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def is_inplace(variant):
            if hasattr(variant, "__wrapped__"):
                return variant.__wrapped__ is op.get_inplace()
            return variant is op.get_inplace()

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex

        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
                                   small_inputs_only=is_slow_gradcheck_env())

        for sample in samples:
            if sample.broadcasts_input and is_inplace(variant):
                continue

            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
            #   and tensors passed as kwargs. The following creates a function that accepts just
            #   the tensors that require grad as varargs, and then recomposes them back into the
            #   original input.

            # Creates gradcheck inputs by identifying tensors requiring grad
            all_args = None
            if is_iterable_of_tensors(sample.input):
                all_args = chain(sample.input, sample.args, sample.kwargs.values())
            else:
                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))

            def _input_recomposition_helper(inputs, inp, input_idx):
                if is_iterable_of_tensors(inp):
                    tensor_list = []
                    for x in inp:
                        if isinstance(x, torch.Tensor) and x.requires_grad:
                            tensor_list.append(inputs[input_idx])
                            input_idx = input_idx + 1
                        else:
                            tensor_list.append(x)
                    return tensor_list, input_idx
                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
                    return inputs[input_idx], input_idx + 1
                else:
                    return inp, input_idx

            def fn(*inputs):
                # Puts inputs back into sample properly
                positional_args = []
                input_idx = 0
                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
                positional_args.append(inp)

                for x in sample.args:
                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
                    positional_args.append(inp)

                # Recreates kwargs
                kwargs = {}
                for k, v in sample.kwargs.items():
                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
                    kwargs[k] = inp

                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            if check == 'gradcheck':
                if check_batched_grad is None:
                    check_batched_grad = op.check_batched_grad
                self.assertTrue(gradcheck(fn, gradcheck_args,
                                          check_batched_grad=check_batched_grad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode,
                                          check_forward_ad=check_forward_ad,
                                          check_backward_ad=check_backward_ad,
                                          check_undefined_grad=True,
                                          check_batched_forward_grad=check_batched_forward_grad))
            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                for gen_non_contig_grad_outputs in (False, True):
                    kwargs = {
                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
                        "check_batched_grad": op.check_batched_gradgrad,
                        "check_grad_dtypes": True,
                        "nondet_tol": op.gradcheck_nondet_tol,
                        "fast_mode": op.gradcheck_fast_mode
                    }
                    if check == "fwgrad_bwgrad":
                        kwargs["check_fwd_over_rev"] = True
                        kwargs["check_rev_over_rev"] = False
                        kwargs["check_batched_grad"] = False
                        kwargs["check_undefined_grad"] = False

                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
            else:
                self.assertTrue(False, msg="Unknown check requested!")
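The recomposition trick used above can be seen in isolation in the following sketch (the op scaled_dot and its signature are hypothetical, standing in for op.gradcheck_wrapper(variant, ...)): only the tensors that require grad are handed to gradcheck as varargs, and the wrapper splices them back into the original call alongside the non-differentiated arguments.

import torch
from torch.autograd import gradcheck

def scaled_dot(x, weight, scale):      # hypothetical op, for illustration only
    return (x * weight).sum() * scale

x = torch.randn(3, dtype=torch.float64, requires_grad=True)
weight = torch.randn(3, dtype=torch.float64, requires_grad=True)
scale = 0.5                            # non-tensor argument, closed over rather than differentiated

def fn(*grad_tensors):
    # Recompose the grad-requiring tensors back into the original call.
    x_, weight_ = grad_tensors
    return scaled_dot(x_, weight_, scale)

assert gradcheck(fn, (x, weight))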
Example No. 3
    def _check_helper(self,
                      device,
                      dtype,
                      op,
                      variant,
                      check,
                      *,
                      check_forward_ad=False,
                      check_backward_ad=True,
                      check_batched_grad=None,
                      check_batched_forward_grad=False):
        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
        # NB: check_backward_ad does not affect gradgradcheck (always True)
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(
                f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def is_inplace(variant):
            if hasattr(variant, "__wrapped__"):
                return variant.__wrapped__ is op.get_inplace()
            return variant is op.get_inplace()

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=True,
            include_conjugated_inputs=include_conjugated_inputs)

        for sample in samples:
            if sample.broadcasts_input and is_inplace(variant):
                continue

            # Note on TensorList inputs
            #
            # gradcheck does not support TensorList inputs so here we pass TensorList
            # inputs of size n as n single Tensor inputs to gradcheck and wrap the op
            # in a function that puts the n Tensor inputs back into a TensorList
            def fn(*inputs):
                # Put tensors back into TensorList since we splat them when passing to gradcheck
                if is_iterable_of_tensors(sample.input):
                    n = len(sample.input)
                    inputs = (inputs[:n], *inputs[n:])
                output = op.gradcheck_wrapper(variant, *inputs,
                                              **sample.kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            # Splat TensorList inputs into single Tensor inputs
            gradcheck_args = (sample.input, ) if isinstance(
                sample.input, torch.Tensor) else tuple(sample.input)
            gradcheck_args += sample.args

            if check == 'gradcheck':
                if check_batched_grad is None:
                    check_batched_grad = op.check_batched_grad
                self.assertTrue(
                    gradcheck(
                        fn,
                        gradcheck_args,
                        check_batched_grad=check_batched_grad,
                        check_grad_dtypes=True,
                        nondet_tol=op.gradcheck_nondet_tol,
                        fast_mode=op.gradcheck_fast_mode,
                        check_forward_ad=check_forward_ad,
                        check_backward_ad=check_backward_ad,
                        check_undefined_grad=True,
                        check_batched_forward_grad=check_batched_forward_grad))
            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
                self.assertFalse(
                    check_forward_ad,
                    msg="Cannot run forward AD check for gradgradcheck")
                for gen_non_contig_grad_outputs in (False, True):
                    kwargs = {
                        "gen_non_contig_grad_outputs":
                        gen_non_contig_grad_outputs,
                        "check_batched_grad": op.check_batched_gradgrad,
                        "check_grad_dtypes": True,
                        "nondet_tol": op.gradcheck_nondet_tol,
                        "fast_mode": op.gradcheck_fast_mode
                    }
                    if check == "fwgrad_bwgrad":
                        kwargs["check_fwd_over_rev"] = True
                        kwargs["check_rev_over_rev"] = False
                        kwargs["check_batched_grad"] = False
                        kwargs["check_undefined_grad"] = False

                    self.assertTrue(gradgradcheck(fn, gradcheck_args,
                                                  **kwargs))
            else:
                self.assertTrue(False, msg="Unknown check requested!")
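A minimal sketch of the TensorList handling described in the comment above (torch.cat stands in for the op under test): gradcheck only accepts individual tensors, so a list of n tensors is splatted into n positional inputs and the wrapper rebuilds the list before calling the op.

import torch
from torch.autograd import gradcheck

tensor_list = [torch.randn(2, dtype=torch.float64, requires_grad=True) for _ in range(3)]

def fn(*inputs):
    # Put the splatted tensors back into a TensorList before calling the op.
    n = len(tensor_list)
    reassembled = list(inputs[:n])
    return torch.cat(reassembled)

assert gradcheck(fn, tuple(tensor_list))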
Example No. 4
    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr,
                                       device=device,
                                       dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr,
                                     device=device,
                                     dtype=dtype)
        actual_result = torch.segment_reduce(
            data=data,
            reduce=reduction,
            lengths=lengths,
            axis=axis,
            unsafe=unsafe,
            initial=initial_value,
        )
        self.assertEqual(expected_result,
                         actual_result,
                         rtol=1e-02,
                         atol=1e-05,
                         equal_nan=True)

        if not check_backward:
            return

        # Test backward
        actual_result.sum().backward()
        self.assertEqual(expected_grad,
                         data.grad,
                         rtol=1e-02,
                         atol=1e-05,
                         equal_nan=True)

        # gradcheck does not work well with bfloat16 or fp16 CPU types;
        # there are also small numerical differences with fp32.
        if dtype not in [torch.half, torch.bfloat16, torch.float]:
            # gradcheck does not handle "nan" inputs, so replace them with 10
            d_non_nan = np.nan_to_num(data_arr, nan=10)
            data = torch.tensor(
                # [10 if v == float("nan") else v for v in data],
                d_non_nan,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            self.assertTrue(
                gradcheck(
                    lambda x: torch.segment_reduce(
                        data=x,
                        reduce=reduction,
                        lengths=lengths,
                        axis=axis,
                        unsafe=unsafe,
                        initial=initial_value,
                    ),
                    (data, ),
                ))
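A minimal sketch of the NaN workaround used above (the lambda is a stand-in for torch.segment_reduce): gradcheck perturbs inputs to build finite differences, so NaN entries are replaced with a finite value before checking.

import numpy as np
import torch
from torch.autograd import gradcheck

data_arr = [1.0, float("nan"), 3.0]
# Replace NaNs with a finite value so the numerical Jacobian stays well defined.
clean = torch.tensor(np.nan_to_num(data_arr, nan=10.0), dtype=torch.float64, requires_grad=True)
assert gradcheck(lambda x: (x * 2).sum(), (clean,))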
Example No. 5
    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        # generate offsets from lengths
        zeros_shape = list(lengths.shape)
        zeros_shape[-1] = 1
        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths),
                            -1).cumsum_(-1)

        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr,
                                       device=device,
                                       dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr,
                                     device=device,
                                     dtype=dtype)
        for mode in ['lengths', 'offsets']:
            segment_reduce_kwargs = dict(axis=axis,
                                         unsafe=unsafe,
                                         initial=initial_value)
            if (mode == 'lengths'):
                segment_reduce_kwargs['lengths'] = lengths
            else:
                segment_reduce_kwargs['offsets'] = offsets
            actual_result = torch.segment_reduce(data=data,
                                                 reduce=reduction,
                                                 **segment_reduce_kwargs)
            self.assertEqual(expected_result,
                             actual_result,
                             rtol=1e-02,
                             atol=1e-05,
                             equal_nan=True)

            if not check_backward:
                return

            # Test backward
            actual_result.sum().backward()
            self.assertEqual(expected_grad,
                             data.grad,
                             rtol=1e-02,
                             atol=1e-05,
                             equal_nan=True)
            data = data.clone().detach().requires_grad_(True)

            # gradcheck does not work well with bfloat16 or fp16 CPU types;
            # there are also small numerical differences with fp32.
            if dtype not in [torch.half, torch.bfloat16, torch.float]:
                # gradcheck does not handle "nan" inputs, so replace them with 10
                d_non_nan = np.nan_to_num(data_arr, nan=10)
                new_data = torch.tensor(
                    # [10 if v == float("nan") else v for v in data],
                    d_non_nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                self.assertTrue(
                    gradcheck(
                        lambda x: torch.segment_reduce(
                            data=x, reduce=reduction, **segment_reduce_kwargs),
                        (new_data, ),
                    ))
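The lengths-to-offsets conversion above reduces to a cumulative sum with a leading zero. A small sketch of the round trip (1-D case only), with torch.diff recovering the lengths the way Example No. 6 below does from indptr:

import torch

lengths = torch.tensor([1, 2, 3, 0])
# Prepend a zero and take the cumulative sum to turn lengths into offsets.
offsets = torch.cat((lengths.new_zeros(1), lengths)).cumsum(0)   # tensor([0, 1, 3, 6, 6])
assert torch.equal(torch.diff(offsets), lengths)                 # diff of offsets recovers the lengths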
Example No. 6
    def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
        val_dtype, length_dtype = dtypes
        # Zero-length segments are filled with the reduction identity, unlike in pytorch_scatter.
        tests = [
            {
                'src': [1, 2, 3, 4, 5, 6],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [3, 12, 0, 6],
                'prod': [2, 60, 1, 6],
                'mean': [1.5, 4, float('nan'), 6],
                'min': [1, 3, float('inf'), 6],
                'max': [2, 5, -float('inf'), 6],
            },
            {
                'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
            },
            {
                'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
                'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
                'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
                'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
            },
            {
                'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
                'index': [[0, 0, 1], [0, 2, 2]],
                'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
                'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
                'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                         [[7, 9], [float('nan'), float('nan')], [11, 12]]],
                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
            },
            {
                'src': [[1, 3], [2, 4]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[4], [6]],
                'prod': [[3], [8]],
                'mean': [[2], [3]],
                'min': [[1], [2]],
                'max': [[3], [4]],
            },
            {
                'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[[4, 4]], [[6, 6]]],
                'prod': [[[3, 3]], [[8, 8]]],
                'mean': [[[2, 2]], [[3, 3]]],
                'min': [[[1, 1]], [[2, 2]]],
                'max': [[[3, 3]], [[4, 4]]],
            },
        ]
        for test in tests:
            data = torch.tensor(test['src'],
                                dtype=val_dtype,
                                device=device,
                                requires_grad=True)
            indptr = torch.tensor(test['indptr'],
                                  dtype=length_dtype,
                                  device=device)
            dim = indptr.ndim - 1
            # calculate lengths from indptr
            lengths = torch.diff(indptr, dim=dim)
            expected = torch.tensor(test[reduce],
                                    dtype=val_dtype,
                                    device=device)

            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduce,
                lengths=lengths,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            # test offsets
            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduce,
                offsets=indptr,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            if val_dtype == torch.float64:

                def fn(x, mode='lengths'):
                    initial = 1
                    # Supply an initial value so gradcheck does not fail on zero-length segments,
                    # where the nan/inf reduction identities produce nans in the numerical Jacobian.
                    if reduce == 'min':
                        initial = 1000
                    elif reduce == 'max':
                        initial = -1000
                    segment_reduce_args = (x, reduce)  # a tuple, not a set: positional order matters when splatting
                    segment_reduce_kwargs = dict(axis=dim,
                                                 unsafe=True,
                                                 initial=initial)
                    if mode == 'lengths':
                        segment_reduce_kwargs[mode] = lengths
                    elif mode == 'offsets':
                        segment_reduce_kwargs[mode] = indptr
                    return torch.segment_reduce(*segment_reduce_args,
                                                **segment_reduce_kwargs)

                self.assertTrue(
                    gradcheck(partial(fn, mode='lengths'),
                              (data.clone().detach().requires_grad_(True))))
                self.assertTrue(
                    gradcheck(partial(fn, mode='offsets'),
                              (data.clone().detach().requires_grad_(True))))
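A minimal sketch of why fn above overrides initial for min/max (the commented output values follow the tables in this example): a zero-length segment is filled with the reduction identity, which for min is +inf and would poison gradcheck's numerical Jacobian, so a large finite initial value is supplied instead.

import torch

data = torch.tensor([1., 2., 3.], dtype=torch.float64)
lengths = torch.tensor([2, 0, 1])
# Without `initial`, the empty segment is filled with the identity of min (+inf).
torch.segment_reduce(data, "min", lengths=lengths)                 # tensor([1., inf, 3.])
# A finite `initial` keeps every segment finite, so gradcheck stays well behaved.
torch.segment_reduce(data, "min", lengths=lengths, initial=1000.)  # tensor([1., 1000., 3.])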
Example No. 7
    def _check_helper(self, device, dtype, op, variant, check):
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(
                f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def is_inplace(variant):
            if hasattr(variant, "__wrapped__"):
                return variant.__wrapped__ is op.get_inplace()
            return variant is op.get_inplace()

        samples = op.sample_inputs(device, dtype, requires_grad=True)
        for sample in samples:
            if sample.broadcasts_input and is_inplace(variant):
                continue

            # Note on TensorList inputs
            #
            # gradcheck does not support TensorList inputs so here we pass TensorList
            # inputs of size n as n single Tensor inputs to gradcheck and wrap the op
            # in a function that puts the n Tensor inputs back into a TensorList
            def fn(*inputs):
                # Put tensors back into TensorList since we splat them when passing to gradcheck
                if is_iterable_of_tensors(sample.input):
                    n = len(sample.input)
                    inputs = (inputs[:n], *inputs[n:])
                output = op.gradcheck_wrapper(variant, *inputs,
                                              **sample.kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            # Splat TensorList inputs into single Tensor inputs
            gradcheck_args = (sample.input, ) if isinstance(
                sample.input, torch.Tensor) else tuple(sample.input)
            gradcheck_args += sample.args

            if check == 'gradcheck':
                self.assertTrue(
                    gradcheck(fn,
                              gradcheck_args,
                              check_batched_grad=op.check_batched_grad,
                              check_grad_dtypes=True,
                              nondet_tol=op.gradcheck_nondet_tol,
                              fast_mode=op.gradcheck_fast_mode))
            elif check == 'gradgradcheck':
                self.assertTrue(
                    gradgradcheck(fn,
                                  gradcheck_args,
                                  gen_non_contig_grad_outputs=False,
                                  check_batched_grad=op.check_batched_gradgrad,
                                  check_grad_dtypes=True,
                                  nondet_tol=op.gradcheck_nondet_tol,
                                  fast_mode=op.gradcheck_fast_mode))
                self.assertTrue(
                    gradgradcheck(fn,
                                  gradcheck_args,
                                  gen_non_contig_grad_outputs=True,
                                  check_batched_grad=op.check_batched_gradgrad,
                                  check_grad_dtypes=True,
                                  nondet_tol=op.gradcheck_nondet_tol,
                                  fast_mode=op.gradcheck_fast_mode))
            else:
                self.assertTrue(False, msg="Unknown check requested!")
Example No. 8
    def _test_simple_1d(self, reduction, device, dtype, unsafe, axis):
        lengths = torch.tensor([1, 2, 3, 0], device=device)
        data = torch.tensor(
            [1, float("nan"), 3, 4, 5, 5],
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        initial_value = 0
        if reduction == "max":
            expected_result = torch.tensor(
                [1, float("nan"), 5, initial_value], device=device, dtype=dtype
            )
            expected_grad = torch.tensor(
                [1, 1, 0, 0, 0.5, 0.5], device=device, dtype=dtype
            )
        elif reduction == "mean":
            expected_result = torch.tensor(
                [1, float("nan"), 4.666, initial_value], device=device, dtype=dtype
            )
            expected_grad = torch.tensor(
                [1.0, 0.5, 0.5, 0.333, 0.333, 0.333], device=device, dtype=dtype
            )
        actual_result = torch.segment_reduce(
            data=data,
            reduce=reduction,
            lengths=lengths,
            axis=axis,
            unsafe=unsafe,
            initial=initial_value,
        )
        self.assertEqual(
            expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
        )

        # TODO: Remove this check once cuda backward support is implemented
        if data.is_cuda:
            return

        # Test backward
        actual_result.sum().backward()
        self.assertEqual(
            expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
        )

        # gradcheck does not work well with bfloat16 or fp16 CPU types;
        # there are also small numerical differences with fp32.
        if dtype not in [torch.half, torch.bfloat16, torch.float]:
            # gradcheck does not like "nan" input
            data = torch.tensor(
                [1, 10, 3, 4, 5, 5],
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            self.assertTrue(
                gradcheck(
                    lambda x: torch.segment_reduce(
                        data=x,
                        reduce=reduction,
                        lengths=lengths,
                        axis=axis,
                        unsafe=unsafe,
                        initial=initial_value,
                    ),
                    (data,),
                )
            )
    def _test_max_simple_1d(self, device, dtype, unsafe, axis):
        lengths = torch.tensor([1, 2, 3, 0], device=device)
        data = torch.tensor(
            [1, float("nan"), 3, 4, 5, 5],
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        initial_value = 0
        expected_result = torch.tensor([1, float("nan"), 5, initial_value],
                                       device=device,
                                       dtype=dtype)
        actual_result = torch.segment_reduce(
            data=data,
            reduce="max",
            lengths=lengths,
            axis=axis,
            unsafe=unsafe,
            initial=initial_value,
        )
        self.assertEqual(expected_result,
                         actual_result,
                         rtol=1e-03,
                         atol=1e-05,
                         equal_nan=True)

        # Backward is only supported for CPU tensors for now; return early on CUDA.
        if data.is_cuda:
            return

        # Test backward
        expected_grad = torch.tensor([1, 1, 0, 0, 0.5, 0.5],
                                     device=device,
                                     dtype=dtype)
        actual_result.sum().backward()
        self.assertEqual(expected_grad,
                         data.grad,
                         rtol=1e-03,
                         atol=1e-05,
                         equal_nan=True)

        # gradcheck does not work well with bfloat16 or fp16 CPU types;
        # there are also small numerical differences with fp32.
        if dtype not in [torch.half, torch.bfloat16, torch.float]:
            # gradcheck does not like "nan" input
            data = torch.tensor(
                [1, 10, 3, 4, 5, 5],
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            self.assertTrue(
                gradcheck(
                    lambda x: torch.segment_reduce(
                        data=x,
                        reduce="max",
                        lengths=lengths,
                        axis=axis,
                        unsafe=unsafe,
                        initial=initial_value,
                    ),
                    (data, ),
                ))