    def test_unsafe_flag(self, device, dtype):
        length_type = dtype
        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
        data = torch.arange(6, dtype=torch.float, device=device)

        # test for error on 1-D lengths
        with self.assertRaisesRegex(RuntimeError,
                                    "Expected all rows of lengths along axis"):
            torch.segment_reduce(data,
                                 'sum',
                                 lengths=lengths,
                                 axis=0,
                                 unsafe=False)

        # test for error on multi-D lengths
        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]],
                                  dtype=length_type,
                                  device=device)
        nd_data = torch.arange(12, dtype=torch.float,
                               device=device).reshape(2, 6)
        with self.assertRaisesRegex(RuntimeError,
                                    "Expected all rows of lengths along axis"):
            torch.segment_reduce(nd_data,
                                 'sum',
                                 lengths=nd_lengths,
                                 axis=1,
                                 unsafe=False)
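
The unsafe flag gates this validation: with unsafe=False, segment_reduce checks that the lengths sum to data.size(axis) before reducing, which is what both assertions above exercise. A minimal sketch of a call that passes the check, with illustrative values not taken from the test:

import torch

data = torch.arange(6, dtype=torch.float)
lengths = torch.tensor([1, 2, 3])  # sums to 6 == data.size(0), so the check passes
out = torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)
# out == tensor([0., 3., 12.]) -- segments [0], [1, 2], [3, 4, 5]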
Example #2
    def test_max_simple_1d(self, device, dtype):
        lengths = torch.tensor([1, 2, 3], device=device)
        data = torch.tensor([1, float("nan"), 3, 4, 5, 6],
                            device=device,
                            dtype=dtype)
        expected_result = torch.tensor([1, float("nan"), 6],
                                       device=device,
                                       dtype=dtype)
        actual_result = torch.segment_reduce(data=data,
                                             reduce="max",
                                             lengths=lengths,
                                             axis=0,
                                             unsafe=False)
        self.assertEqual(expected_result,
                         actual_result,
                         rtol=1e-03,
                         atol=1e-05,
                         equal_nan=True)
        actual_result = torch.segment_reduce(data=data,
                                             reduce="max",
                                             lengths=lengths,
                                             axis=-1,
                                             unsafe=False)
        self.assertEqual(expected_result,
                         actual_result,
                         rtol=1e-03,
                         atol=1e-05,
                         equal_nan=True)
Example #3
    def _test_max_simple_1d(self, device, dtype, unsafe, axis):
        lengths = torch.tensor([1, 2, 3], device=device)
        data = torch.tensor(
            [1, float("nan"), 3, 4, 5, 5],
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor([1, float("nan"), 5],
                                       device=device,
                                       dtype=dtype)
        actual_result = torch.segment_reduce(data=data,
                                             reduce="max",
                                             lengths=lengths,
                                             axis=axis,
                                             unsafe=unsafe)
        self.assertEqual(expected_result,
                         actual_result,
                         rtol=1e-03,
                         atol=1e-05,
                         equal_nan=True)

        # Backward is only supported for cpu tensors for now; return early on CUDA
        if data.is_cuda:
            return

        # Test backward
        expected_grad = torch.tensor([1, 1, 0, 0, 0.5, 0.5],
                                     device=device,
                                     dtype=dtype)
        actual_result.sum().backward()
        self.assertEqual(expected_grad,
                         data.grad,
                         rtol=1e-03,
                         atol=1e-05,
                         equal_nan=True)

        # gradcheck does not work well with bfloat16 or fp16 cpu types;
        # there is also a small numerical difference with fp32
        if dtype not in [torch.half, torch.bfloat16, torch.float]:
            # gradcheck does not like "nan" input
            data = torch.tensor(
                [1, 10, 3, 4, 5, 5],
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            self.assertTrue(
                gradcheck(
                    lambda x: torch.segment_reduce(data=x,
                                                   reduce="max",
                                                   lengths=lengths,
                                                   axis=axis,
                                                   unsafe=unsafe),
                    (data, ),
                ))
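
The expected_grad above encodes how the max backward handles ties: when several elements attain a segment's maximum, the incoming gradient is split evenly among them, hence 0.5 for each of the two trailing 5s. A minimal CPU sketch of just that behavior, with illustrative values:

import torch

data = torch.tensor([4., 5., 5.], requires_grad=True)
out = torch.segment_reduce(data, 'max', lengths=torch.tensor([3]), axis=0)
out.sum().backward()
# data.grad == tensor([0.0000, 0.5000, 0.5000])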
Example #4
def fn(x):
    initial = 1
    # supply initial values to prevent gradcheck from failing for 0 length segments
    # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
    if reduce == 'min':
        initial = 1000
    elif reduce == 'max':
        initial = -1000
    return torch.segment_reduce(x, reduce, lengths=lengths, axis=dim, unsafe=True, initial=initial)
def fn(x, mode='lengths'):
    initial = 1
    # supply initial values to prevent gradcheck from failing for 0 length segments
    # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
    if reduce == 'min':
        initial = 1000
    elif reduce == 'max':
        initial = -1000
    # positional args must be an ordered tuple, not a set
    segment_reduce_args = (x, reduce)
    segment_reduce_kwargs = dict(axis=dim,
                                 unsafe=True,
                                 initial=initial)
    if mode == 'lengths':
        segment_reduce_kwargs[mode] = lengths
    elif mode == 'offsets':
        segment_reduce_kwargs[mode] = indptr
    return torch.segment_reduce(*segment_reduce_args,
                                **segment_reduce_kwargs)
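
The initial values in these helpers matter because a zero-length segment falls back to the reduction identity (inf for min, -inf for max), which produces nans in gradcheck's numerical jacobian. A minimal sketch of the difference, with illustrative values:

import torch

data = torch.tensor([1., 2., 3.])
lengths = torch.tensor([2, 0, 1])
torch.segment_reduce(data, 'min', lengths=lengths, axis=0, unsafe=True)
# tensor([1., inf, 3.]) -- the empty segment yields the identity
torch.segment_reduce(data, 'min', lengths=lengths, axis=0, unsafe=True, initial=1000)
# tensor([1., 1000., 3.]) -- an explicit initial keeps the output finite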
Example #6
    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr,
                                       device=device,
                                       dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr,
                                     device=device,
                                     dtype=dtype)
        actual_result = torch.segment_reduce(
            data=data,
            reduce=reduction,
            lengths=lengths,
            axis=axis,
            unsafe=unsafe,
            initial=initial_value,
        )
        self.assertEqual(expected_result,
                         actual_result,
                         rtol=1e-02,
                         atol=1e-05,
                         equal_nan=True)

        if not check_backward:
            return

        # Test backward
        actual_result.sum().backward()
        self.assertEqual(expected_grad,
                         data.grad,
                         rtol=1e-02,
                         atol=1e-05,
                         equal_nan=True)

        # gradcheck does not work well with bfloat16 or fp16 cpu types;
        # there is also a small numerical difference with fp32
        if dtype not in [torch.half, torch.bfloat16, torch.float]:
            # gradcheck does not like "nan" input, so replace nan values with 10
            d_non_nan = np.nan_to_num(data_arr, nan=10)
            data = torch.tensor(
                d_non_nan,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            self.assertTrue(
                gradcheck(
                    lambda x: torch.segment_reduce(
                        data=x,
                        reduce=reduction,
                        lengths=lengths,
                        axis=axis,
                        unsafe=unsafe,
                        initial=initial_value,
                    ),
                    (data, ),
                ))
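
The dtype filter before gradcheck means the check effectively runs only for double-precision inputs, where finite differences are accurate enough. A minimal sketch of such a call, assuming illustrative values and the public torch.autograd.gradcheck entry point:

import torch
from torch.autograd import gradcheck

x = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)
ok = gradcheck(
    lambda t: torch.segment_reduce(
        t, 'sum', lengths=torch.tensor([1, 2]), axis=0, unsafe=True),
    (x, ))
# ok is True: sum's analytical gradient matches the numerical jacobian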
    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        # generate offsets from lengths
        zeros_shape = list(lengths.shape)
        zeros_shape[-1] = 1
        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths),
                            -1).cumsum_(-1)

        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr,
                                       device=device,
                                       dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr,
                                     device=device,
                                     dtype=dtype)
        for mode in ['lengths', 'offsets']:
            segment_reduce_kwargs = dict(axis=axis,
                                         unsafe=unsafe,
                                         initial=initial_value)
            if (mode == 'lengths'):
                segment_reduce_kwargs['lengths'] = lengths
            else:
                segment_reduce_kwargs['offsets'] = offsets
            actual_result = torch.segment_reduce(data=data,
                                                 reduce=reduction,
                                                 **segment_reduce_kwargs)
            self.assertEqual(expected_result,
                             actual_result,
                             rtol=1e-02,
                             atol=1e-05,
                             equal_nan=True)

            if not check_backward:
                return

            # Test backward
            actual_result.sum().backward()
            self.assertEqual(expected_grad,
                             data.grad,
                             rtol=1e-02,
                             atol=1e-05,
                             equal_nan=True)
            data = data.clone().detach().requires_grad_(True)

            # gradcheck does not work well with bfloat16 or fp16 cpu types;
            # there is also a small numerical difference with fp32
            if dtype not in [torch.half, torch.bfloat16, torch.float]:
                # gradcheck does not like "nan" input, so replace nan values with 10
                d_non_nan = np.nan_to_num(data_arr, nan=10)
                new_data = torch.tensor(
                    d_non_nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                self.assertTrue(
                    gradcheck(
                        lambda x: torch.segment_reduce(
                            data=x, reduce=reduction, **segment_reduce_kwargs),
                        (new_data, ),
                    ))
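
The offsets generated above are the exclusive cumulative sum of the lengths, and either form selects the same segments. A minimal sketch of that round trip in 1-D, with illustrative values:

import torch

lengths = torch.tensor([1, 2, 3])
offsets = torch.cat((lengths.new_zeros(1), lengths)).cumsum(0)  # tensor([0, 1, 3, 6])
data = torch.arange(6, dtype=torch.float)
by_lengths = torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=True)
by_offsets = torch.segment_reduce(data, 'sum', offsets=offsets, axis=0, unsafe=True)
# both are tensor([0., 3., 12.])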
    def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
        val_dtype, length_dtype = dtypes
        # zero-length segments are filled with reduction inits contrary to pytorch_scatter.
        tests = [
            {
                'src': [1, 2, 3, 4, 5, 6],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [3, 12, 0, 6],
                'prod': [2, 60, 1, 6],
                'mean': [1.5, 4, float('nan'), 6],
                'min': [1, 3, float('inf'), 6],
                'max': [2, 5, -float('inf'), 6],
            },
            {
                'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')],
                        [11, 12]],
            },
            {
                'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
                'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
                'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
                'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
            },
            {
                'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
                'index': [[0, 0, 1], [0, 2, 2]],
                'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
                'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
                'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                         [[7, 9], [float('nan'), float('nan')], [11, 12]]],
                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
            },
            {
                'src': [[1, 3], [2, 4]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[4], [6]],
                'prod': [[3], [8]],
                'mean': [[2], [3]],
                'min': [[1], [2]],
                'max': [[3], [4]],
            },
            {
                'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[[4, 4]], [[6, 6]]],
                'prod': [[[3, 3]], [[8, 8]]],
                'mean': [[[2, 2]], [[3, 3]]],
                'min': [[[1, 1]], [[2, 2]]],
                'max': [[[3, 3]], [[4, 4]]],
            },
        ]
        for test in tests:
            data = torch.tensor(test['src'],
                                dtype=val_dtype,
                                device=device,
                                requires_grad=True)
            indptr = torch.tensor(test['indptr'],
                                  dtype=length_dtype,
                                  device=device)
            dim = indptr.ndim - 1
            # calculate lengths from indptr
            lengths = torch.diff(indptr, dim=dim)
            expected = torch.tensor(test[reduce],
                                    dtype=val_dtype,
                                    device=device)

            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduce,
                lengths=lengths,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            # test offsets
            actual_result = torch.segment_reduce(
                data=data,
                reduce=reduce,
                offsets=indptr,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            if val_dtype == torch.float64:

                def fn(x, mode='lengths'):
                    initial = 1
                    # supply initial values to prevent gradcheck from failing for 0 length segments
                    # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
                    if reduce == 'min':
                        initial = 1000
                    elif reduce == 'max':
                        initial = -1000
                    # positional args must be an ordered tuple, not a set
                    segment_reduce_args = (x, reduce)
                    segment_reduce_kwargs = dict(axis=dim,
                                                 unsafe=True,
                                                 initial=initial)
                    if mode == 'lengths':
                        segment_reduce_kwargs[mode] = lengths
                    elif mode == 'offsets':
                        segment_reduce_kwargs[mode] = indptr
                    return torch.segment_reduce(*segment_reduce_args,
                                                **segment_reduce_kwargs)

                self.assertTrue(
                    gradcheck(partial(fn, mode='lengths'),
                              (data.clone().detach().requires_grad_(True), )))
                self.assertTrue(
                    gradcheck(partial(fn, mode='offsets'),
                              (data.clone().detach().requires_grad_(True), )))
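
The torch.diff call in this test is the inverse of the offsets construction in _test_common: it recovers per-segment lengths from an indptr row. A minimal sketch using one of the indptr vectors from the test data above:

import torch

indptr = torch.tensor([0, 2, 5, 5, 6])
lengths = torch.diff(indptr)  # tensor([2, 3, 0, 1]) -- note the zero-length third segment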
Example #9
    def _test_simple_1d(self, reduction, device, dtype, unsafe, axis):
        lengths = torch.tensor([1, 2, 3, 0], device=device)
        data = torch.tensor(
            [1, float("nan"), 3, 4, 5, 5],
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        initial_value = 0
        if reduction == "max":
            expected_result = torch.tensor(
                [1, float("nan"), 5, initial_value], device=device, dtype=dtype
            )
            expected_grad = torch.tensor(
                [1, 1, 0, 0, 0.5, 0.5], device=device, dtype=dtype
            )
        elif reduction == "mean":
            expected_result = torch.tensor(
                [1, float("nan"), 4.666, initial_value], device=device, dtype=dtype
            )
            expected_grad = torch.tensor(
                [1.0, 0.5, 0.5, 0.333, 0.333, 0.333], device=device, dtype=dtype
            )
        actual_result = torch.segment_reduce(
            data=data,
            reduce=reduction,
            lengths=lengths,
            axis=axis,
            unsafe=unsafe,
            initial=initial_value,
        )
        self.assertEqual(
            expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
        )

        # TODO: Remove this check once cuda backward support is implemented
        if data.is_cuda:
            return

        # Test backward
        actual_result.sum().backward()
        self.assertEqual(
            expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
        )

        # gradcheck does not work well with bfloat16 or fp16 cpu types;
        # there is also a small numerical difference with fp32
        if dtype not in [torch.half, torch.bfloat16, torch.float]:
            # gradcheck does not like "nan" input
            data = torch.tensor(
                [1, 10, 3, 4, 5, 5],
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            self.assertTrue(
                gradcheck(
                    lambda x: torch.segment_reduce(
                        data=x,
                        reduce=reduction,
                        lengths=lengths,
                        axis=axis,
                        unsafe=unsafe,
                        initial=initial_value,
                    ),
                    (data,),
                )
            )
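
The mean expectations above follow directly from the segment boundaries: the third segment [4, 5, 5] averages to 14/3 ≈ 4.666, and the empty fourth segment takes the supplied initial value. A minimal sketch with the nan replaced by a finite value, values illustrative:

import torch

data = torch.tensor([1., 2., 3., 4., 5., 5.])
lengths = torch.tensor([1, 2, 3, 0])
torch.segment_reduce(data, 'mean', lengths=lengths, axis=0, unsafe=True, initial=0)
# tensor([1.0000, 2.5000, 4.6667, 0.0000])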