Example #1
class TestPythonJiterator(TestCase):
    @skipCUDAIfRocm
    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    @skipCUDAIf(_get_torch_cuda_version() < (11, 6), "On cuda 11.3, nvrtcCompileProgram is taking too long to "
                "compile jiterator generated kernels for non-contiguous input that requires dynamic-casting.")
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16 : tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        extra_args = {}
        if alpha is not None:
            extra_args["alpha"] = alpha
        if beta is not None:
            extra_args["beta"] = beta

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        jitted_fn = create_jit_fn(code_string, is_train=False)

        def ref_fn(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = ref_fn(a, b, is_train=is_train)
        result = jitted_fn(a, b, is_train=is_train)
        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    def test_multiple_functors(self, device):
        code_string = '''
        template <typename T> T fn(T x, T mask) { return x * mask; }
        template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }
        '''
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = ref_fn(a, b, c)
        result = jitted_fn(a, b, c)
        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, num_inputs):
        inputs = []
        for i in range(num_inputs):
            inputs.append(torch.rand(3, device='cuda').mul(10))

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(*inputs):
            return torch.sum(torch.stack(inputs), dim=0)

        expected = ref_fn(*inputs)
        result = jitted_fn(*inputs)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, num_outputs):
        input = torch.rand(3, device='cuda')

        output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)])
        function_body = ""
        for i in range(num_outputs):
            function_body += f"out{i} = input + {i};\n"
        code_string = f"template <typename T> T my_kernel(T input, {output_string}) {{ {function_body} }}"

        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs)

        def ref_fn(input):
            outputs = []
            for i in range(num_outputs):
                outputs.append(input + i)

            if num_outputs == 1:
                return outputs[0]
            return tuple(outputs)

        expected = ref_fn(input)
        result = jitted_fn(input)

        for i in range(num_outputs):
            self.assertEqual(expected[i], result[i])

    @skipCUDAIfRocm
    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        with self.assertRaises(Exception):
            jitted_fn = create_jit_fn(code_string)
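
The tests above rely on module-level fixtures `ref_fn` and `jitted_fn` that are not shown in this snippet. As a hedged sketch only (the kernel string and the default extra arguments below are assumptions, not the definitions used by the actual test file), such a pair can be set up with torch.cuda.jiterator like this:

import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn

# Hypothetical elementwise kernel with two tensor inputs and two scalar extras;
# compilation happens lazily on the first CUDA call.
code_string = "template <typename T> T my_kernel(T x, T y, T alpha, T beta) { return -x + alpha * y + beta; }"
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)

def ref_fn(x, y, alpha=1, beta=1):
    # Plain PyTorch reference that the jiterator kernel is compared against.
    return -x + alpha * y + beta
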
Example #2
class TestScatterGather(TestCase):
    # Fills an index tensor with valid indices
    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o, unique_indices=True):
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                for k in range(1 if dim == 2 else o):
                    ii = [i, j, k]
                    ii[dim] = slice(0, idx.size(dim) + 1)
                    if unique_indices:
                        idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
                    else:
                        idx[tuple(ii)] = torch.randint(dim_size, (elems_per_row,))

    @dtypes(torch.float32, torch.complex64)
    def test_gather(self, device, dtype):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        src = make_tensor((m, n, o), device=device, dtype=dtype)
        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = make_tensor(idx_size, device=device, dtype=torch.long)
        self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)

        actual = torch.gather(src, dim, idx)
        expected = torch.zeros(idx_size, device=device, dtype=dtype)
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    expected[i, j, k] = src[tuple(ii)]
        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Guarded because torch.max isn't defined for complex types
        if not dtype.is_complex:
            src = make_tensor((3, 4, 5), device=device, dtype=dtype)
            expected, idx = src.max(2, True)
            actual = torch.gather(src, 2, idx)
            self.assertEqual(actual, expected, atol=0, rtol=0)

    @dtypes(torch.bool)
    def test_gather_bool(self, device, dtype):
        src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype)
        idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
        actual = torch.gather(src, 1, idx)
        expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype)
        self.assertEqual(actual, expected, atol=0, rtol=0)

    @parametrize("sparse_grad", [False, True])
    @dtypes(torch.float32, torch.float64)
    def test_gather_backward_with_empty_index_tensor(self, device, dtype, sparse_grad):
        dim = -1
        input = torch.rand([10, 5], dtype=dtype, device=device, requires_grad=True)
        index = torch.randint(0, 2, [3, 0], dtype=torch.int64, device=device)
        res = torch.gather(input, dim, index, sparse_grad=sparse_grad)
        res.sum().backward()
        grad = input.grad.to_dense() if sparse_grad else input.grad
        expected_grad = torch.zeros_like(input, requires_grad=False)
        self.assertEqual(grad, expected_grad, atol=0, rtol=0)

    def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction,
                           unique_indices=True, include_self=True):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long)
        self._fill_indices(idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o, unique_indices)

        if is_scalar:
            src = random.random()
        else:
            src_size = [random.randint(1, 5) + s for s in idx_size]
            src = make_tensor(tuple(src_size), device=device, dtype=dtype)

        base = make_tensor((m, n, o), device=device, dtype=dtype)
        if reduction is not None:
            if fn is torch.Tensor.scatter_reduce_:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction, include_self=include_self)
            else:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction)
        else:
            actual = fn(base.clone(), dim, idx, src)

        expected = base.clone()
        counts = torch.zeros(base.shape, dtype=torch.long, device=device) + include_self
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    if fn is torch.Tensor.scatter_add_:
                        expected[tuple(ii)] += src[i, j, k]
                    else:
                        # method may be 'scatter_', 'scatter', 'scatter_reduce'
                        # or 'scatter_reduce_', the former two might have a reduction argument
                        # while the latter two always do
                        value = src if is_scalar else src[i, j, k]

                        if ((not include_self) and counts[tuple(ii)] == 0):
                            expected[tuple(ii)] = value
                        else:
                            if reduction == "add" or reduction == "sum":
                                expected[tuple(ii)] += value
                            elif reduction == "multiply" or reduction == "prod":
                                expected[tuple(ii)] *= value
                            elif reduction == "amax":
                                expected[tuple(ii)] = max(expected[tuple(ii)], value)
                            elif reduction == "amin":
                                expected[tuple(ii)] = min(expected[tuple(ii)], value)
                            elif reduction == "mean":
                                expected[tuple(ii)] += value
                            else:
                                expected[tuple(ii)] = value

                        counts[tuple(ii)] += 1

        if (reduction == "mean"):
            counts.masked_fill_(counts == 0, 1)
            if (dtype.is_floating_point or dtype.is_complex):
                expected /= counts
            else:
                expected.div_(counts, rounding_mode="floor")

        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Tests empty index
        dst = make_tensor((2, 2), device=device, dtype=dtype)
        idx = torch.tensor((), device=device, dtype=torch.long)
        src = make_tensor((2, 2), device=device, dtype=dtype)
        if reduction is not None:
            actual = fn(dst, 0, idx, src, reduce=reduction)
        else:
            actual = fn(dst, 0, idx, src)
        self.assertEqual(actual, dst, atol=0, rtol=0)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                is_scalar=False, reduction=None)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__scalar(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                is_scalar=True, reduction=None)

    # FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat'
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)})
    @dtypesIfCUDA(torch.float16, torch.float32)
    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__reductions(self, device, dtype):
        for reduction in ("add", "multiply"):
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=False, reduction=reduction)
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=True, reduction=reduction)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_add_(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
                                is_scalar=False, reduction=None)

    @dtypes(torch.float32)
    def test_scatter_add_mult_index_base(self, device, dtype):
        m, n = 30, 40
        idx = torch.zeros(m, n, device=device, dtype=torch.long)
        src = torch.ones(m, n, device=device, dtype=dtype)
        res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
        res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)

        self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
        self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)

    # FIXME: discrepancy between bool ReduceAdd on CUDA and CPU (a + b on CPU and buggy a && b on CUDA)
    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    def test_scatter_reduce_sum(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='sum', unique_indices=False,
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_prod(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='prod', unique_indices=False,
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_mean(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='mean', unique_indices=False,
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amax(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amax', unique_indices=False,
                                    include_self=include_self)
            # simple test for nan/inf propagation
            if (dtype.is_floating_point):
                input = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -float('inf'), -float('inf'), 2, float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
                input.scatter_reduce_(0, idx, src, 'amax', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if (include_self):
                    expected_result[1] = 0
                self.assertEqual(input, expected_result)


    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amin(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amin', unique_indices=False,
                                    include_self=include_self)
            # simple test for nan/inf propagation
            if (dtype.is_floating_point):
                input = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -2, -float('inf'), float('inf'), float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
                input.scatter_reduce_(0, idx, src, 'amin', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if (include_self):
                    expected_result[2] = 0
                self.assertEqual(input, expected_result)
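
The nested loops in the tests above re-implement the gather and scatter indexing rules by hand. For orientation, a minimal standalone sketch of those rules for dim=1 (standard torch.gather / Tensor.scatter_add_ semantics, independent of the test harness):

import torch

src = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.]])
idx = torch.tensor([[2, 0],
                    [1, 1]])

# gather: out[i][j] = src[i][idx[i][j]]
out = torch.gather(src, 1, idx)            # tensor([[3., 1.], [5., 5.]])

# scatter_add_: res[i][idx[i][j]] += ones[i][j]; duplicate indices accumulate
res = torch.zeros(2, 3).scatter_add_(1, idx, torch.ones(2, 2))
# tensor([[1., 0., 1.],
#         [0., 2., 0.]])
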
Example #3
class TestModule(TestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    precision = 1e-5
    rel_tol = 1e-5

    def _assert_module_parameters_and_buffer_are(self, module, device, dtype):
        # Check device placement and dtype for created parameters and buffers.
        # Only verify floating point dtypes since that's what the kwarg or methods
        # such as `float()` apply to.
        if not isinstance(device, torch.device):
            device = torch.device(device)

        def _check_module(items, name, device=device, dtype=dtype):
            for item_name, item in items:
                self.assertEqual(
                    item.device, device,
                    f'{name} {item_name} is on device {item.device} instead of the expected device {device}'
                )
                if item.dtype.is_floating_point:
                    self.assertEqual(
                        item.dtype, dtype,
                        f'{name} {item_name} is of dtype {item.dtype} instead of the expected dtype {dtype}'
                    )

        _check_module(module.named_parameters(), "Parameter")
        _check_module(module.named_buffers(), "Buffer")

    @skipIfMps  # the test doesn't work on MPS as double types are not supported
    @modules(module_db)
    def test_forward(self, device, dtype, module_info):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=False)
        dtype_to_method_caller = {
            torch.float32: methodcaller("float"),
            torch.float64: methodcaller("double"),
        }
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)

                # === Do forward pass. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                outputs = m(*args, **kwargs)

                # === Compare outputs to a reference if one is specified. ===
                # TODO: Handle precision
                reference_fn = module_input.reference_fn
                if reference_fn is not None:
                    ref_outputs = reference_fn(m, *args, **kwargs)
                    self.assertEqual(outputs, ref_outputs)

                # === Use the method call and verify the parameters and buffers ===
                if dtype in dtype_to_method_caller:
                    dtype_to_method_caller[dtype](m)
                    m(*args, **kwargs)
                    self._assert_module_parameters_and_buffer_are(
                        m, device, dtype)

    # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
    # They should be applied to any created parameters and buffers.
    @modules(module_db)
    def test_factory_kwargs(self, device, dtype, module_info):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=False)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            # Check if this module creates parameters or registers buffers.
            # The mock magic here passes through to the real Parameter / register_buffer
            # logic and is only used to check call inputs.
            module_creates_params_or_buffers = False
            parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
            with patch.object(torch.nn.Parameter, '__new__', parameter_new):
                register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
                with patch.object(torch.nn.Module, 'register_buffer',
                                  register_buffer):
                    m = module_cls(*args, **kwargs)

                    # Check if a parameter or buffer was created with a tensor not passed to the constructor.
                    constructor_tensors = get_tensors_from(args, kwargs)
                    for mock in [parameter_new.mock, register_buffer.mock]:
                        for call_args, call_kwargs in mock.call_args_list:
                            call_tensors = get_tensors_from(
                                call_args, call_kwargs)
                            if len(
                                    call_tensors
                            ) > 0 and not constructor_tensors.intersection(
                                    call_tensors):
                                module_creates_params_or_buffers = True
                                break

            if not module_creates_params_or_buffers:
                continue

            # Instantiate module with the factory kwargs.
            kwargs.update({
                'device': device,
                'dtype': dtype,
            })

            if issubclass(module_info.module_cls,
                          torch.nn.modules.lazy.LazyModuleMixin):
                # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
                uninit_param_new = mock_wrapper(
                    torch.nn.UninitializedParameter.__new__)
                with patch.object(torch.nn.UninitializedParameter, '__new__',
                                  uninit_param_new):
                    uninit_buffer_new = mock_wrapper(
                        torch.nn.UninitializedBuffer.__new__)
                    with patch.object(torch.nn.UninitializedBuffer, '__new__',
                                      uninit_buffer_new):
                        m = module_cls(*args, **kwargs)
                        uninit_param_new.mock.assert_has_calls([
                            call(device=device, dtype=dtype)
                            for _ in uninit_param_new.mock.mock_calls
                        ])
                        uninit_buffer_new.mock.assert_has_calls([
                            call(device=device, dtype=dtype)
                            for _ in uninit_buffer_new.mock.mock_calls
                        ])
            else:
                # Check device placement and dtype for created parameters and buffers.
                # Only verify floating point dtypes since that's what the kwarg applies to.
                m = module_cls(*args, **kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

    @onlyCUDA
    @modules(module_db)
    def test_multiple_device_transfer(self, device, dtype, module_info):
        module_cls = module_info.module_cls
        module_inputs_device = module_info.module_inputs_func(
            module_info, device=device, dtype=dtype, requires_grad=False)
        module_inputs_cpu = module_info.module_inputs_func(module_info,
                                                           device="cpu",
                                                           dtype=dtype,
                                                           requires_grad=False)
        for module_input_device, module_input_cpu in zip(
                module_inputs_device, module_inputs_cpu):
            if module_input_device.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)

                # === Do forward pass on GPU ===
                input_device_args = module_input_device.forward_input.args
                input_device_kwargs = module_input_device.forward_input.kwargs
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                # === Move to CPU ===
                input_cpu_args = module_input_cpu.forward_input.args
                input_cpu_kwargs = module_input_cpu.forward_input.kwargs
                m.cpu()
                m(*input_cpu_args, **input_cpu_kwargs)
                self._assert_module_parameters_and_buffer_are(m, "cpu", dtype)

                # === Move back to GPU and forward pass ===
                m.cuda()
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                if torch.cuda.device_count() >= 2:
                    # === Test that cross-GPU transfer works ===
                    def _to_device1(objs):
                        if isinstance(objs, (tuple, list)):
                            return type(objs)(_to_device1(item)
                                              for item in objs)
                        elif isinstance(objs, dict):
                            return {
                                name: _to_device1(item)
                                for name, item in objs.items()
                            }
                        elif isinstance(objs, torch.Tensor):
                            return objs.cuda(1)
                        else:
                            return objs

                    input_device_1_args = _to_device1(input_device_args)
                    input_device_1_kwargs = _to_device1(input_device_kwargs)

                    m.cuda(1)
                    with torch.cuda.device(1):
                        m(*input_device_1_args, **input_device_1_kwargs)
                    self._assert_module_parameters_and_buffer_are(
                        m, torch.device("cuda:1"), dtype)

    @modules(module_db)
    def test_repr(self, device, dtype, module_info):
        # Test module can be represented with repr and str without errors.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=False)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)

            # Check that these methods do not raise errors
            m.__repr__()
            str(m)

    @skipIfMps
    @modules(module_db)
    def test_pickle(self, device, dtype, module_info):
        # Test that module can be pickled and unpickled.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=False)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)

                # === Do forward pass. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                output = m(*args, **kwargs)

                # === Check unpickled module gives the same output. ===
                with tempfile.TemporaryFile() as f:
                    torch.save(m, f)
                    f.seek(0)
                    m_copy = torch.load(f)
                    output_from_copy = m_copy(*args, **kwargs)
                    self.assertEqual(output, output_from_copy)

    @modules([
        module_info for module_info in module_db
        if 'inplace' in signature(module_info.module_cls).parameters
    ])
    @skipMeta
    def test_check_inplace(self, device, dtype, module_info):
        # Check if the inplace variant of the module gives the same result as the out of place
        # variant.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=True)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m_op = module_cls(*args, **kwargs, inplace=False)
            m_op.to(device).to(dtype)
            m_inplace = module_cls(*args, **kwargs, inplace=True)
            m_inplace.to(device).to(dtype)

            # === Inplace modules only support inplace operations on the first argument ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

            # ===  Do not allow the first input to be in input_kwargs ===
            forward_sig = signature(m_op).parameters
            self.assertGreaterEqual(len(forward_sig), 1)
            first_param_name, _ = next(iter(forward_sig.items()))
            self.assertNotIn(first_param_name, input_kwargs)

            # === Out of place operation does not write to original tensor ===
            self.assertGreaterEqual(len(input_args), 1)
            input_version = input_args[0]._version
            with freeze_rng_state():
                output_op = m_op(*input_args, **input_kwargs)
            self.assertEqual(input_args[0]._version, input_version)

            # === Check that the inplace operation gives the same result ===
            input_arg_copy = deepcopy(input_args)
            input_arg_clone = tuple(i.clone() for i in input_arg_copy)
            with freeze_rng_state():
                output_ip = m_inplace(*input_arg_clone, **input_kwargs)
            self.assertNotEqual(input_arg_clone[0]._version, input_version)
            self.assertEqual(output_op, output_ip)

            # === Check that the gradients are the same ===
            grad = output_op.data.clone().normal_()
            output_op.backward(grad)
            output_ip.backward(grad)
            self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)

    def _traverse_obj(self, obj, func):
        if isinstance(obj, (tuple, list)):
            return type(obj)(self._traverse_obj(o, func) for o in obj)
        elif isgenerator(obj):
            return tuple(self._traverse_obj(o, func) for o in obj)
        elif isinstance(obj, dict):
            return {
                name: self._traverse_obj(o, func)
                for name, o in obj.items()
            }
        elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
            return func(obj)

    def _retain_grad(self, obj):
        # Gradients need to be retained to check them later. This is useful when
        # non-leaf tensors are present in the graph.
        def inner_retain_grad(obj):
            if obj.requires_grad:
                obj.retain_grad()

        self._traverse_obj(obj, inner_retain_grad)

    def _get_grads(self, obj):
        def inner_get_grad(obj):
            if obj.requires_grad:
                return obj.grad

        return self._traverse_obj(obj, inner_get_grad)

    def _zero_grad(self, obj):
        def inner_zero_grad(obj):
            if obj.grad is not None:
                obj.grad = None

        self._traverse_obj(obj, inner_zero_grad)

    @skipIfMps
    @modules(module_db)
    def test_non_contiguous_tensors(self, device, dtype, module_info):
        # Check modules work with non-contiguous tensors

        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=True)

        def _make_non_contiguous(obj):
            def inner_make_non_contiguous(obj):
                # Scalar tensors cannot be made non-contiguous
                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                    return obj

                out = torch.repeat_interleave(obj, 2, dim=-1)
                out = out[..., ::2].detach()
                out.requires_grad = obj.requires_grad
                return out

            return self._traverse_obj(obj, inner_make_non_contiguous)

        def _can_be_noncontiguous(obj):
            if isinstance(obj, (tuple, list)):
                return any(_can_be_noncontiguous(o) for o in obj)
            elif isinstance(obj, dict):
                return any(_can_be_noncontiguous(o) for o in obj.values())
            # Scalar tensors cannot be non-contiguous
            if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                return False
            return True

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if not (_can_be_noncontiguous(input_args)
                    or _can_be_noncontiguous(input_kwargs)):
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            self._retain_grad((input_args, input_kwargs))

            # === Forward with default input ===
            with freeze_rng_state():
                default_output = m(*input_args, **input_kwargs)
                if isinstance(default_output, torch.Tensor):
                    grad_output = default_output.clone().detach_().normal_()
                    default_output.backward(grad_output, retain_graph=True)
                else:
                    grad_output = tuple(
                        self._traverse_obj(
                            o, lambda o: o.clone().detach_().normal_())
                        for o in default_output)
                    flattened_default_output, _ = torch.utils._pytree.tree_flatten(
                        default_output)
                    flattened_grad_output, _ = torch.utils._pytree.tree_flatten(
                        grad_output)
                    for o, g_o in zip(flattened_default_output,
                                      flattened_grad_output):
                        o.backward(g_o, retain_graph=True)

            default_input_args_grad, default_input_kwargs_grad = deepcopy(
                self._get_grads((input_args, input_kwargs)))
            default_param_grad = deepcopy([p.grad for p in m.parameters()])

            # === Construct non-contiguous tensors ===
            nc_input_args, nc_input_kwargs = _make_non_contiguous(
                (input_args, input_kwargs))
            nc_grad_output = _make_non_contiguous(grad_output)

            # === Compare results with non-contiguous and contiguous tensors ===
            inputs = [(input_args, input_kwargs),
                      (nc_input_args, nc_input_kwargs)]
            grads = [grad_output, nc_grad_output]

            for (in_args, in_kwargs), g_out in product(inputs, grads):
                g_out_copy = deepcopy(g_out)
                self._zero_grad((in_args, in_kwargs))
                self._zero_grad(m.parameters())

                with freeze_rng_state():
                    out = m(*in_args, **in_kwargs)
                    if isinstance(out, torch.Tensor):
                        out.backward(g_out_copy, retain_graph=True)
                    else:
                        flattened_out, _ = torch.utils._pytree.tree_flatten(
                            out)
                        flattened_g_out_copy, _ = torch.utils._pytree.tree_flatten(
                            g_out_copy)
                        for o, g_o in zip(flattened_out, flattened_g_out_copy):
                            o.backward(g_o, retain_graph=True)

                input_args_grad, input_kwargs_grad = self._get_grads(
                    (in_args, in_kwargs))
                self.assertEqual(out, default_output)
                self.assertEqual(input_args_grad,
                                 default_input_args_grad,
                                 atol=1e-4,
                                 rtol=0)
                self.assertEqual(input_kwargs_grad,
                                 default_input_kwargs_grad,
                                 atol=1e-4,
                                 rtol=0)

                param_grad = [p.grad for p in m.parameters()]
                self.assertEqual(param_grad, default_param_grad)

    def _test_gradients_helper(self, device, dtype, module_info, check):
        # Check gradients
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=True)
        # === Set nondet tol for gradcheck to user-defined value if on CUDA and cuDNN is enabled ===
        gradcheck_nondet_tol = 0.0
        if (torch.device(device).type == 'cuda'
                and torch.backends.cudnn.enabled):
            gradcheck_nondet_tol = module_info.gradcheck_nondet_tol

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            params = tuple(m.parameters())

            # === Lazy modules need to see an input to initialize params before gradcheck is run. ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if issubclass(module_info.module_cls,
                          torch.nn.modules.lazy.LazyModuleMixin):
                with torch.no_grad():
                    m(*input_args, **input_kwargs)

            # === Perform gradient check on the input_args ===
            other_kwargs = {}
            kwarg_tensors = []
            for name, obj in input_kwargs.items():
                if isinstance(obj, torch.Tensor):
                    kwarg_tensors.append((name, obj))
                else:
                    other_kwargs[name] = obj

            grad_input = input_args + params + tuple(
                obj for (_, obj) in kwarg_tensors)

            flat_input, flat_spec = torch.utils._pytree.tree_flatten(
                grad_input)

            def fn_to_gradcheck(*flat_input_and_params):
                input_and_params = torch.utils._pytree.tree_unflatten(
                    flat_input_and_params, flat_spec)
                new_input_args = input_and_params[:len(input_args)]
                kwarg_args = input_and_params[-len(kwarg_tensors):]
                new_kwargs = {
                    name: obj
                    for (name, _), obj in zip(kwarg_tensors, kwarg_args)
                }

                with freeze_rng_state():
                    output = m(*new_input_args, **new_kwargs, **other_kwargs)
                    output_flattened, _ = torch.utils._pytree.tree_flatten(
                        output)
                    return output_flattened

            self.assertTrue(
                check(fn_to_gradcheck,
                      flat_input,
                      nondet_tol=gradcheck_nondet_tol))

    @modules(module_db, allowed_dtypes=[torch.double])
    def test_grad(self, device, dtype, module_info):
        self._test_gradients_helper(device, dtype, module_info, gradcheck)

    @modules([m for m in module_db if m.supports_gradgrad],
             allowed_dtypes=[torch.double])
    def test_gradgrad(self, device, dtype, module_info):
        self._test_gradients_helper(device, dtype, module_info, gradgradcheck)

    @onlyCUDA
    @toleranceOverride({
        torch.float32: tol(5e-2, 0),
        torch.float64: tol(4e-4, 0)
    })
    @modules(module_db)
    def test_cpu_gpu_parity(self, device, dtype, module_info):
        # Test cpu and gpu results are the same
        module_cls = module_info.module_cls
        module_inputs_cpu = module_info.module_inputs_func(module_info,
                                                           device="cpu",
                                                           dtype=dtype,
                                                           requires_grad=True)

        def _to_device(obj):
            if isinstance(obj, torch.Tensor):
                res = obj.detach().to(device=device)
                res.requires_grad = obj.requires_grad
                return res
            elif isinstance(obj, tuple):
                return tuple(_to_device(o) for o in obj)
            elif isinstance(obj, dict):
                return {key: _to_device(o) for key, o in obj.items()}
            else:
                return deepcopy(obj)

        for module_input in module_inputs_cpu:

            # === Move input from cpu to device ===
            cpu_forward_args = module_input.forward_input.args
            cpu_forward_kwargs = module_input.forward_input.kwargs

            gpu_forward_args, gpu_forward_kwargs = _to_device(
                (cpu_forward_args, cpu_forward_kwargs))

            self._retain_grad((cpu_forward_args, cpu_forward_kwargs,
                               gpu_forward_args, gpu_forward_kwargs))

            # === Construct module on cpu and gpu ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
            gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)

            for cpu_p, gpu_p in zip(cpu_module.parameters(),
                                    gpu_module.parameters()):
                gpu_p.data.copy_(cpu_p)

            # === Compare forward output between cpu and gpu ===
            cpu_outputs = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
            gpu_outputs = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

            self.assertEqual(cpu_outputs, gpu_outputs)

            # === Run backwards on CPU and GPU and compare results ===
            def check_backward(cpu_output, gpu_output):
                cpu_grad_output = cpu_output.clone().normal_()
                gpu_grad_output = cpu_grad_output.type_as(gpu_output)

                cpu_output.backward(cpu_grad_output, retain_graph=True)
                gpu_output.backward(gpu_grad_output, retain_graph=True)

                cpu_grad_input = self._get_grads(cpu_forward_args)
                gpu_grad_input = self._get_grads(gpu_forward_args)
                self.assertEqual(cpu_grad_input, gpu_grad_input)

                for cpu_p, gpu_p in zip(cpu_module.parameters(),
                                        gpu_module.parameters()):
                    self.assertEqual(cpu_p.grad, gpu_p.grad)

                cpu_grad_kwarg_input = self._get_grads(cpu_forward_kwargs)
                gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs)
                self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input)

            for _ in range(5):
                if isinstance(cpu_outputs, torch.Tensor):
                    check_backward(cpu_outputs, gpu_outputs)
                else:
                    flatten_cpu_outputs, _ = torch.utils._pytree.tree_flatten(
                        cpu_outputs)
                    flatten_gpu_outputs, _ = torch.utils._pytree.tree_flatten(
                        gpu_outputs)
                    for cpu_output, gpu_output in zip(flatten_cpu_outputs,
                                                      flatten_gpu_outputs):
                        check_backward(cpu_output, gpu_output)

    @skipIfMps
    @modules(module_db)
    def test_memory_format(self, device, dtype, module_info):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=False)
        module_memformat_affects_out = module_info.module_memformat_affects_out

        def _get_mem_formats(channels_last=False, channels_last_3d=False):
            if channels_last:
                return ([torch.contiguous_format, torch.channels_last], [
                    torch.preserve_format, torch.contiguous_format,
                    torch.channels_last
                ])
            elif channels_last_3d:
                return ([torch.contiguous_format, torch.channels_last_3d], [
                    torch.preserve_format, torch.contiguous_format,
                    torch.channels_last_3d
                ])
            else:
                return ([torch.contiguous_format],
                        [torch.preserve_format, torch.contiguous_format])

        # Check that at least one Tensor input has dim == n
        def _check_dims(obj, n):
            if isinstance(obj, torch.Tensor):
                return obj.dim() == n
            elif isinstance(obj, (tuple, list)):
                return any(_check_dims(o, n) for o in obj)
            else:
                return False

        # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format
        def _to_mem_format(mem_format, obj):
            def inner_to_mem_format(obj):
                d = obj.dim()
                if ((mem_format == torch.channels_last and d != 4)
                        or (mem_format == torch.channels_last_3d and d != 5)):
                    return obj
                return obj.to(memory_format=mem_format)

            return self._traverse_obj(obj, inner_to_mem_format)

        def _check_out_mem_format(output, input_mem_format, module_mem_format):
            def inner_check_out_mem_format(output):
                d = output.dim()
                if (d == 4 and ((input_mem_format == torch.channels_last) or
                                (module_mem_format == torch.channels_last
                                 and module_memformat_affects_out))):
                    self.assertTrue(
                        output.is_contiguous(
                            memory_format=torch.channels_last))
                elif (d == 5
                      and ((input_mem_format == torch.channels_last_3d) or
                           (module_mem_format == torch.channels_last_3d
                            and module_memformat_affects_out))):
                    self.assertTrue(
                        output.is_contiguous(
                            memory_format=torch.channels_last_3d))
                else:
                    self.assertTrue(output.is_contiguous())

            return self._traverse_obj(output, inner_check_out_mem_format)

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            supports_channels_last = _check_dims(
                module_input.forward_input.args, 4)
            supports_channels_last_3d = _check_dims(
                module_input.forward_input.args, 5)
            input_mem_formats, module_mem_formats = _get_mem_formats(
                supports_channels_last, supports_channels_last_3d)

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)

                # === Get output in (contiguous, contiguous) configuration. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                desired_outputs = m(*args, **kwargs)

                for input_mem_format in input_mem_formats:
                    # === Change memformat of input. ===
                    module_input.forward_input.args = _to_mem_format(
                        input_mem_format, module_input.forward_input.args)
                    module_input.forward_input.kwargs = _to_mem_format(
                        input_mem_format, module_input.forward_input.kwargs)

                    for module_mem_format in module_mem_formats:
                        # === Change memformat of module ===
                        m.to(memory_format=module_mem_format)

                        # === Do forward pass. ===
                        args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                        outputs = m(*args, **kwargs)

                        # === Compare outputs to (contiguous, contiguous) output. ===
                        if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
                            self.assertEqual(outputs, desired_outputs)

                        # === Check mem format of output. ===
                        _check_out_mem_format(outputs, input_mem_format,
                                              module_mem_format)
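
TestModule is generic over devices and dtypes; the `@modules(module_db)` decorator only parametrizes it, and the class is turned into runnable, device-specific test cases elsewhere in the file. A sketch of the boilerplate that typically closes such a test file (standard torch.testing._internal helpers, shown only to indicate how the class above gets instantiated):

from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests

# Generates per-device classes such as TestModuleCPU / TestModuleCUDA.
instantiate_device_type_tests(TestModule, globals())

if __name__ == '__main__':
    run_tests()
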
Example #4
# Database of ModuleInfo entries in alphabetical order.
module_db: List[ModuleInfo] = [
    ModuleInfo(torch.nn.AvgPool1d,
               module_inputs_func=module_inputs_torch_nn_AvgPool1d),
    ModuleInfo(torch.nn.ELU,
               module_inputs_func=module_inputs_torch_nn_ELU),
    ModuleInfo(torch.nn.L1Loss,
               module_inputs_func=module_inputs_torch_nn_L1Loss),
    ModuleInfo(torch.nn.Linear,
               module_inputs_func=module_inputs_torch_nn_Linear),
    ModuleInfo(torch.nn.Bilinear,
               module_inputs_func=module_inputs_torch_nn_Bilinear,
               decorators=[
                   DecorateInfo(
                       toleranceOverride({
                           torch.float32: tol(atol=1e-4, rtol=1e-4),
                           torch.float64: tol(atol=1e-4, rtol=1e-4)}),
                       'TestModule', 'test_forward', device_type='cpu')
               ]),
    ModuleInfo(torch.nn.NLLLoss,
               module_inputs_func=module_inputs_torch_nn_NLLLoss),
    ModuleInfo(torch.nn.GaussianNLLLoss,
               module_inputs_func=module_inputs_torch_nn_GaussianNLLLoss),
    ModuleInfo(torch.nn.Hardswish,
               module_inputs_func=module_inputs_torch_nn_Hardswish,
               supports_gradgrad=False),
    ModuleInfo(torch.nn.TransformerEncoderLayer,
               module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer),
    ModuleInfo(torch.nn.TransformerDecoderLayer,
               module_inputs_func=module_inputs_torch_nn_TransformerDecoderLayer),
    ModuleInfo(torch.nn.Transformer,
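
Each module_inputs_func referenced in the database returns a list of ModuleInput entries describing how to construct the module and what to feed its forward pass. A hedged sketch of that pattern with a hypothetical entry (the real functions live in torch.testing._internal.common_modules and may differ in detail, including the exact import locations of ModuleInput and FunctionInput):

from functools import partial

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_modules import ModuleInput, FunctionInput


def module_inputs_torch_nn_MyModule(module_info, device, dtype, requires_grad, **kwargs):
    # Hypothetical illustration only; real module_inputs_* functions usually
    # return several ModuleInputs covering different constructor/forward variants.
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 4))),
                    desc='default'),
    ]
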
Example #5
     backward_dtypes=floating_types(),
     sample_inputs_func=sample_inputs_i0_i1,
     supports_forward_ad=True,
     supports_fwgrad_bwgrad=True,
 ),
 UnaryUfuncInfo(
     "special.i1",
     aten_name="special_i1",
     ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1)
     if TEST_SCIPY else None,
     dtypes=all_types_and(torch.bool),
     dtypesIfCUDA=all_types_and(torch.bool),
     sample_inputs_func=sample_inputs_i0_i1,
     decorators=(DecorateInfo(
         toleranceOverride({
             torch.float32: tol(atol=1e-4, rtol=0),
             torch.bool: tol(atol=1e-4, rtol=0),
         })), ),
     skips=(DecorateInfo(
         unittest.skip("Incorrect result!"),
         "TestUnaryUfuncs",
         "test_reference_numerics_large",
         dtypes=(torch.int8, ),
     ), ),
     supports_fwgrad_bwgrad=True,
     supports_forward_ad=True,
 ),
 UnaryUfuncInfo(
     "special.i1e",
     aten_name="special_i1e",
     ref=scipy.special.i1e if TEST_SCIPY else None,
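
np_unary_ufunc_integer_promotion_wrapper is referenced above but not defined in this fragment. Torch's unary float ufuncs promote integral and bool inputs to the default floating dtype, while the SciPy reference would keep them integral, so the NumPy-side reference is wrapped to match. A hedged re-implementation for illustration only (the real helper in PyTorch's OpInfo infrastructure may differ):

import numpy as np
import torch


def np_unary_ufunc_integer_promotion_wrapper(fn):
    # Hypothetical sketch: promote integer/bool NumPy inputs to torch's
    # default floating dtype before calling the reference function.
    def wrapped(x):
        if np.issubdtype(x.dtype, np.integer) or x.dtype == np.bool_:
            np_default = torch.empty(0, dtype=torch.get_default_dtype()).numpy().dtype
            return fn(x.astype(np_default))
        return fn(x)
    return wrapped
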
Example #6
# Database of ModuleInfo entries in alphabetical order.
module_db: List[ModuleInfo] = [
    ModuleInfo(torch.nn.AvgPool1d,
               module_inputs_func=module_inputs_torch_nn_AvgPool1d),
    ModuleInfo(torch.nn.ELU, module_inputs_func=module_inputs_torch_nn_ELU),
    ModuleInfo(torch.nn.L1Loss,
               module_inputs_func=module_inputs_torch_nn_L1Loss),
    ModuleInfo(torch.nn.Linear,
               module_inputs_func=module_inputs_torch_nn_Linear),
    ModuleInfo(torch.nn.Bilinear,
               module_inputs_func=module_inputs_torch_nn_Bilinear,
               decorators=[
                   DecorateInfo(toleranceOverride({
                       torch.float32:
                       tol(atol=1e-4, rtol=1e-4),
                       torch.float64:
                       tol(atol=1e-4, rtol=1e-4)
                   }),
                                'TestModule',
                                'test_forward',
                                device_type='cpu')
               ]),
    ModuleInfo(torch.nn.NLLLoss,
               module_inputs_func=module_inputs_torch_nn_NLLLoss),
    ModuleInfo(torch.nn.Hardswish,
               module_inputs_func=module_inputs_torch_nn_Hardswish,
               supports_gradgrad=False),
    ModuleInfo(
        torch.nn.TransformerEncoderLayer,
        module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
Example #7
         "TestNormalizeOperators",
         "test_normalize_operator_exhaustive",
     ),
     # FIXME: sum reduces all dimensions when dim=[]
     DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
     DecorateInfo(
         unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
     ),
     # RuntimeError: undefined value tensor
     DecorateInfo(
         unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
     ),
 ),
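 # Looser tolerances for low-precision dtypes in the masked-reduction
 # reference and small-input tests.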
 decorators=[
     DecorateInfo(
         toleranceOverride({torch.bfloat16: tol(atol=1e-03, rtol=1e-03)}),
         "TestReductions",
         "test_reference_masked",
     ),
     DecorateInfo(
         toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
         "TestReductions",
         "test_reference_masked",
     ),
     DecorateInfo(
         toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}),
         "TestReductions",
         "test_ref_small_input",
     ),
 ],
 sample_inputs_func=sample_inputs_masked_reduction,
Example #8
class TestScatterGather(TestCase):
    # Fills an index tensor with valid indices
    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                for k in range(1 if dim == 2 else o):
                    ii = [i, j, k]
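                    # Replace the dim-th coordinate with a full slice so the
                    # whole line along `dim` is filled with distinct indices
                    # drawn from [0, dim_size).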
                    ii[dim] = slice(0, idx.size(dim) + 1)
                    idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]

    @dtypes(torch.float32, torch.complex64)
    def test_gather(self, device, dtype):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        src = make_tensor((m, n, o), device=device, dtype=dtype)
        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = make_tensor(idx_size, device=device, dtype=torch.long)
        self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)

        actual = torch.gather(src, dim, idx)
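        # Reference gather: each output element takes src at the same
        # coordinates, but with the dim-th coordinate replaced by idx[i, j, k].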
        expected = torch.zeros(idx_size, device=device, dtype=dtype)
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    expected[i, j, k] = src[tuple(ii)]
        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Guarded because torch.max isn't defined for complex types
        if not dtype.is_complex:
            src = make_tensor((3, 4, 5), device=device, dtype=dtype)
            expected, idx = src.max(2, True)
            actual = torch.gather(src, 2, idx)
            self.assertEqual(actual, expected, atol=0, rtol=0)

    @dtypes(torch.bool)
    def test_gather_bool(self, device, dtype):
        src = torch.tensor(((False, True), (True, True)),
                           device=device,
                           dtype=dtype)
        idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
        actual = torch.gather(src, 1, idx)
        expected = torch.tensor(((False, False), (True, True)),
                                device=device,
                                dtype=dtype)
        self.assertEqual(actual, expected, atol=0, rtol=0)

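    # Shared helper for the scatter variants below: builds a random index
    # tensor and a random src (tensor or Python scalar), applies `fn`
    # (optionally with reduce=reduction), and checks the result against an
    # explicit Python loop. Also checks that an empty index tensor leaves
    # the destination unchanged.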
    def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long)
        self._fill_indices(idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o)

        if is_scalar:
            src = random.random()
        else:
            src_size = [random.randint(1, 5) + s for s in idx_size]
            src = make_tensor(tuple(src_size), device=device, dtype=dtype)

        base = make_tensor((m, n, o), device=device, dtype=dtype)
        if reduction is not None:
            actual = fn(base.clone(), dim, idx, src, reduce=reduction)
        else:
            actual = fn(base.clone(), dim, idx, src)

        expected = base.clone()
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    if fn is torch.Tensor.scatter_add_:
                        expected[tuple(ii)] += src[i, j, k]
                    else:
                        # fn may be 'scatter_' or 'scatter'; both accept an
                        # optional reduce= argument
                        value = src if is_scalar else src[i, j, k]

                        if reduction == "add":
                            expected[tuple(ii)] += value
                        elif reduction == "multiply":
                            expected[tuple(ii)] *= value
                        else:
                            expected[tuple(ii)] = value

        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Tests empty index
        dst = make_tensor((2, 2), device=device, dtype=dtype)
        idx = torch.tensor((), device=device, dtype=torch.long)
        src = make_tensor((2, 2), device=device, dtype=dtype)
        if reduction is not None:
            actual = fn(dst, 0, idx, src, reduce=reduction)
        else:
            actual = fn(dst, 0, idx, src)
        self.assertEqual(actual, dst, atol=0, rtol=0)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_,
                                device=device,
                                dtype=dtype,
                                is_scalar=False,
                                reduction=None)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__scalar(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_,
                                device=device,
                                dtype=dtype,
                                is_scalar=True,
                                reduction=None)

    # FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat'
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)})
    @dtypesIfCUDA(torch.float16, torch.float32)
    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__reductions(self, device, dtype):
        for reduction in ("add", "multiply"):
            self._test_scatter_base(torch.Tensor.scatter_,
                                    device=device,
                                    dtype=dtype,
                                    is_scalar=False,
                                    reduction=reduction)
            self._test_scatter_base(torch.Tensor.scatter_,
                                    device=device,
                                    dtype=dtype,
                                    is_scalar=True,
                                    reduction=reduction)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_add_(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_add_,
                                device=device,
                                dtype=dtype,
                                is_scalar=False,
                                reduction=None)

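    # With an all-zero index, scatter_add_ accumulates every row of src
    # (dim=0) or every column (dim=1) into index 0, so res0's first row
    # should be all m and res1's first column all n.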
    @dtypes(torch.float32)
    def test_scatter_add_mult_index_base(self, device, dtype):
        m, n = 30, 40
        idx = torch.zeros(m, n, device=device, dtype=torch.long)
        src = torch.ones(m, n, device=device, dtype=dtype)
        res0 = torch.zeros(m, n, device=device,
                           dtype=dtype).scatter_add_(0, idx, src)
        res1 = torch.zeros(m, n, device=device,
                           dtype=dtype).scatter_add_(1, idx, src)

        self.assertEqual(res0[0, :],
                         m * torch.ones(n, device=device, dtype=dtype),
                         atol=0,
                         rtol=0)
        self.assertEqual(res1[:, 0],
                         n * torch.ones(m, device=device, dtype=dtype),
                         atol=0,
                         rtol=0)