class TestPythonJiterator(TestCase):
    """Tests for jiterator kernels that are built from C++ code strings.

    The dtype and extra-argument tests call ``jitted_fn`` and ``ref_fn``,
    which are not defined in this class; they are expected to exist at
    module level.
    """

    def _check_binary_all_dtypes(self, device, dtypes, shape_strides):
        """Compare ``jitted_fn`` with ``ref_fn`` on two strided 9-element buffers.

        Args:
            device: device on which the random buffers are created.
            dtypes: pair ``(dtype_a, dtype_b)`` applied to the two buffers.
            shape_strides: pair of ``(shape, stride)`` tuples, one per input,
                forwarded to ``Tensor.as_strided``.
        """
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        """Every dtype pair on contiguous inputs."""
        self._check_binary_all_dtypes(device, dtypes, shape_strides)

    @skipCUDAIfRocm
    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    @skipCUDAIf(_get_torch_cuda_version() < (11, 6), "On cuda 11.3, nvrtcCompileProgram is taking too long to "
                "compile jiterator generated kernels for non-contiguous input that requires dynamic-casting.")
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        """Every dtype pair on non-contiguous inputs."""
        self._check_binary_all_dtypes(device, dtypes, shape_strides)

    @skipCUDAIfRocm
    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16 : tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        """Optional ``alpha`` / ``beta`` keyword arguments reach the kernel.

        A value of ``None`` means the keyword is omitted from the call.
        """
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        # Only forward the keywords that were actually requested.
        extra_args = {
            name: value
            for name, value in (("alpha", alpha), ("beta", beta))
            if value is not None
        }

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        """A ``bool`` extra argument selects the branch taken inside the kernel."""
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        conditional_fn = create_jit_fn(code_string, is_train=False)

        def conditional_ref(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = conditional_ref(a, b, is_train=is_train)
        result = conditional_fn(a, b, is_train=is_train)
        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    def test_multiple_functors(self, device):
        """A code string may define a helper functor used by the last functor."""
        code_string = '''
        template <typename T> T fn(T x, T mask) { return x * mask; }
        template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }
        '''
        fused_fn = create_jit_fn(code_string)

        def fused_ref(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = fused_ref(a, b, c)
        result = fused_fn(a, b, c)
        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, num_inputs):
        """Kernels summing ``num_inputs`` tensors match ``torch.sum`` over a stack."""
        inputs = [torch.rand(3, device='cuda').mul(10) for _ in range(num_inputs)]

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        sum_fn = create_jit_fn(code_string)

        def sum_ref(*tensors):
            return torch.sum(torch.stack(tensors), dim=0)

        expected = sum_ref(*inputs)
        result = sum_fn(*inputs)

        self.assertEqual(expected, result)

    @skipCUDAIfRocm
    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, num_outputs):
        """Multi-output kernels write ``input + i`` into output ``i``."""
        inp = torch.rand(3, device='cuda')

        output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)])
        function_body = "".join(f"out{i} = input + {i};\n" for i in range(num_outputs))
        code_string = f"template <typename T> T my_kernel(T input, {output_string}) {{ {function_body} }}"

        multi_out_fn = create_multi_output_jit_fn(code_string, num_outputs)

        def multi_out_ref(x):
            outputs = [x + i for i in range(num_outputs)]
            if num_outputs == 1:
                return outputs[0]
            return tuple(outputs)

        expected = multi_out_ref(inp)
        result = multi_out_fn(inp)

        if num_outputs == 1:
            # The reference returns a bare tensor for a single output, so
            # indexing with [0] would only compare the first element.
            # Compare the whole tensors instead.
            # NOTE(review): assumes the jitted function also returns a bare
            # tensor when num_outputs == 1, mirroring the reference.
            self.assertEqual(expected, result)
        else:
            for i in range(num_outputs):
                self.assertEqual(expected[i], result[i])

    @skipCUDAIfRocm
    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        """Code strings with a malformed function name are rejected."""
        with self.assertRaises(Exception):
            create_jit_fn(code_string)
class TestScatterGather(TestCase):
    """Checks gather/scatter ops against element-by-element Python references."""

    # Fills an index tensor with valid indices
    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o, unique_indices=True):
        """Fill ``idx`` in place with index values in ``[0, dim_size)`` along ``dim``.

        ``m``, ``n`` and ``o`` are the loop extents for dimensions 0, 1 and 2;
        the dimension equal to ``dim`` is written through a slice and is
        therefore visited only once. With ``unique_indices`` the values of a
        row come from a random permutation (no repeats); otherwise they are
        drawn with ``torch.randint`` and may repeat.
        """
        extent_0 = 1 if dim == 0 else m
        extent_1 = 1 if dim == 1 else n
        extent_2 = 1 if dim == 2 else o
        for a in range(extent_0):
            for b in range(extent_1):
                for c in range(extent_2):
                    where = [a, b, c]
                    where[dim] = slice(0, idx.size(dim) + 1)
                    where = tuple(where)
                    if unique_indices:
                        idx[where] = torch.randperm(dim_size)[0:elems_per_row]
                    else:
                        idx[where] = torch.randint(dim_size, (elems_per_row,))

    @dtypes(torch.float32, torch.complex64)
    def test_gather(self, device, dtype):
        """``torch.gather`` on a random 3-d tensor matches a manual lookup loop."""
        m, n, o = (random.randint(10, 20) for _ in range(3))
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        src = make_tensor((m, n, o), device=device, dtype=dtype)
        idx_size = [
            elems_per_row if axis == dim else extent
            for axis, extent in enumerate((m, n, o))
        ]
        idx = make_tensor(idx_size, device=device, dtype=torch.long)
        self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)

        actual = torch.gather(src, dim, idx)

        # Reference: read each element through its index one at a time.
        expected = torch.zeros(idx_size, device=device, dtype=dtype)
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    src_pos = [i, j, k]
                    src_pos[dim] = idx[i, j, k]
                    expected[i, j, k] = src[tuple(src_pos)]
        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Guarded because torch.max isn't defined for complex types
        if dtype.is_complex:
            return
        src = make_tensor((3, 4, 5), device=device, dtype=dtype)
        expected, idx = src.max(2, True)
        actual = torch.gather(src, 2, idx)
        self.assertEqual(actual, expected, atol=0, rtol=0)

    @dtypes(torch.bool)
    def test_gather_bool(self, device, dtype):
        """``torch.gather`` on a small boolean matrix."""
        src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype)
        idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
        expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype)

        actual = torch.gather(src, 1, idx)

        self.assertEqual(actual, expected, atol=0, rtol=0)

    @parametrize("sparse_grad", [False, True])
    @dtypes(torch.float32, torch.float64)
    def test_gather_backward_with_empty_index_tensor(self, device, dtype, sparse_grad):
        """Backward through a gather with an empty index yields an all-zero grad."""
        inp = torch.rand([10, 5], dtype=dtype, device=device, requires_grad=True)
        index = torch.randint(0, 2, [3, 0], dtype=torch.int64, device=device)

        gathered = torch.gather(inp, -1, index, sparse_grad=sparse_grad)
        gathered.sum().backward()

        if sparse_grad:
            grad = inp.grad.to_dense()
        else:
            grad = inp.grad
        expected_grad = torch.zeros_like(inp, requires_grad=False)
        self.assertEqual(grad, expected_grad, atol=0, rtol=0)

    def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction,
                           unique_indices=True, include_self=True):
        """Run scatter variant ``fn`` on random data and compare with a reference loop.

        Args:
            fn: unbound in-place scatter method, e.g. ``torch.Tensor.scatter_``.
            device, dtype: where and how the test tensors are created.
            is_scalar: scatter a random Python float instead of a tensor.
            reduction: reduction name passed as ``reduce=``, or ``None`` to
                call ``fn`` without it.
            unique_indices: forwarded to ``_fill_indices``.
            include_self: passed to ``fn`` only when it is ``scatter_reduce_``
                and a reduction is given; also seeds the reference counts.

        Also checks that scattering with an empty index leaves the destination
        unchanged.
        """
        m, n, o = (random.randint(10, 20) for _ in range(3))
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long)
        self._fill_indices(idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o, unique_indices)

        if is_scalar:
            src = random.random()
        else:
            # The source is strictly larger than the index in every dimension.
            src_size = [random.randint(1, 5) + s for s in idx_size]
            src = make_tensor(tuple(src_size), device=device, dtype=dtype)

        base = make_tensor((m, n, o), device=device, dtype=dtype)

        call_kwargs = {}
        if reduction is not None:
            call_kwargs["reduce"] = reduction
            if fn is torch.Tensor.scatter_reduce_:
                call_kwargs["include_self"] = include_self
        actual = fn(base.clone(), dim, idx, src, **call_kwargs)

        expected = base.clone()
        # Starts at 1 when the original value takes part in the reduction.
        counts = torch.zeros(base.shape, dtype=torch.long, device=device) + include_self
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    dst_pos = [i, j, k]
                    dst_pos[dim] = idx[i, j, k]
                    dst_pos = tuple(dst_pos)

                    if fn is torch.Tensor.scatter_add_:
                        expected[dst_pos] += src[i, j, k]
                        continue

                    # method may be 'scatter_', 'scatter', 'scatter_reduce'
                    # or 'scatter_reduce_', the former two might have a reduction argument
                    # while the latter two always do
                    value = src if is_scalar else src[i, j, k]

                    if (not include_self) and counts[dst_pos] == 0:
                        # First write to this slot replaces the original value.
                        expected[dst_pos] = value
                    elif reduction in ("add", "sum", "mean"):
                        # "mean" accumulates here and is divided by counts below.
                        expected[dst_pos] += value
                    elif reduction in ("multiply", "prod"):
                        expected[dst_pos] *= value
                    elif reduction == "amax":
                        expected[dst_pos] = max(expected[dst_pos], value)
                    elif reduction == "amin":
                        expected[dst_pos] = min(expected[dst_pos], value)
                    else:
                        expected[dst_pos] = value

                    counts[dst_pos] += 1

        if reduction == "mean":
            # Avoid dividing untouched slots by zero.
            counts.masked_fill_(counts == 0, 1)
            if dtype.is_floating_point or dtype.is_complex:
                expected /= counts
            else:
                expected.div_(counts, rounding_mode="floor")

        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Tests empty index
        empty_dst = make_tensor((2, 2), device=device, dtype=dtype)
        empty_idx = torch.tensor((), device=device, dtype=torch.long)
        empty_src = make_tensor((2, 2), device=device, dtype=dtype)
        reduce_kwargs = {} if reduction is None else {"reduce": reduction}
        actual = fn(empty_dst, 0, empty_idx, empty_src, **reduce_kwargs)
        self.assertEqual(actual, empty_dst, atol=0, rtol=0)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_(self, device, dtype):
        """``scatter_`` with a tensor source and no reduction."""
        self._test_scatter_base(
            torch.Tensor.scatter_,
            device=device, dtype=dtype, is_scalar=False, reduction=None,
        )

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__scalar(self, device, dtype):
        """``scatter_`` with a scalar source and no reduction."""
        self._test_scatter_base(
            torch.Tensor.scatter_,
            device=device, dtype=dtype, is_scalar=True, reduction=None,
        )

    # FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat'
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)})
    @dtypesIfCUDA(torch.float16, torch.float32)
    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__reductions(self, device, dtype):
        """``scatter_`` with ``add`` / ``multiply`` for tensor and scalar sources."""
        for reduction in ("add", "multiply"):
            for is_scalar in (False, True):
                self._test_scatter_base(
                    torch.Tensor.scatter_,
                    device=device, dtype=dtype, is_scalar=is_scalar, reduction=reduction,
                )

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_add_(self, device, dtype):
        """``scatter_add_`` with a tensor source."""
        self._test_scatter_base(
            torch.Tensor.scatter_add_,
            device=device, dtype=dtype, is_scalar=False, reduction=None,
        )

    @dtypes(torch.float32)
    def test_scatter_add_mult_index_base(self, device, dtype):
        """All-zero indices accumulate every source element into row/column 0."""
        m, n = 30, 40
        idx = torch.zeros(m, n, device=device, dtype=torch.long)
        src = torch.ones(m, n, device=device, dtype=dtype)

        along_rows = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
        along_cols = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)

        self.assertEqual(along_rows[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
        self.assertEqual(along_cols[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)

    # FIXME: discrepancy between bool ReduceAdd on CUDA and CPU (a + b on CPU and buggy a && b on CUDA)
    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    def test_scatter_reduce_sum(self, device, dtype):
        """``scatter_reduce_`` with ``sum``, with and without ``include_self``."""
        for include_self in (True, False):
            self._test_scatter_base(
                torch.Tensor.scatter_reduce_,
                device=device, dtype=dtype, is_scalar=False, reduction='sum',
                unique_indices=False, include_self=include_self,
            )

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_prod(self, device, dtype):
        """``scatter_reduce_`` with ``prod``, with and without ``include_self``."""
        for include_self in (True, False):
            self._test_scatter_base(
                torch.Tensor.scatter_reduce_,
                device=device, dtype=dtype, is_scalar=False, reduction='prod',
                unique_indices=False, include_self=include_self,
            )

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_mean(self, device, dtype):
        """``scatter_reduce_`` with ``mean``, with and without ``include_self``."""
        for include_self in (True, False):
            self._test_scatter_base(
                torch.Tensor.scatter_reduce_,
                device=device, dtype=dtype, is_scalar=False, reduction='mean',
                unique_indices=False, include_self=include_self,
            )

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amax(self, device, dtype):
        """``scatter_reduce_`` with ``amax``, plus a nan/inf check for float dtypes."""
        nan, inf = float('nan'), float('inf')
        for include_self in (True, False):
            self._test_scatter_base(
                torch.Tensor.scatter_reduce_,
                device=device, dtype=dtype, is_scalar=False, reduction='amax',
                unique_indices=False, include_self=include_self,
            )
            # simple test for nan/inf propagation
            if not dtype.is_floating_point:
                continue
            dst = torch.zeros(3, device=device, dtype=dtype)
            src = torch.tensor([1, nan, -inf, -inf, 2, inf], device=device, dtype=dtype)
            idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
            dst.scatter_reduce_(0, idx, src, 'amax', include_self=include_self)
            expected_result = torch.tensor([nan, -inf, inf], device=device, dtype=dtype)
            if include_self:
                # The initial 0 beats the two -inf sources.
                expected_result[1] = 0
            self.assertEqual(dst, expected_result)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amin(self, device, dtype):
        """``scatter_reduce_`` with ``amin``, plus a nan/inf check for float dtypes."""
        nan, inf = float('nan'), float('inf')
        for include_self in (True, False):
            self._test_scatter_base(
                torch.Tensor.scatter_reduce_,
                device=device, dtype=dtype, is_scalar=False, reduction='amin',
                unique_indices=False, include_self=include_self,
            )
            # simple test for nan/inf propagation
            if not dtype.is_floating_point:
                continue
            dst = torch.zeros(3, device=device, dtype=dtype)
            src = torch.tensor([1, nan, -2, -inf, inf, inf], device=device, dtype=dtype)
            idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
            dst.scatter_reduce_(0, idx, src, 'amin', include_self=include_self)
            expected_result = torch.tensor([nan, -inf, inf], device=device, dtype=dtype)
            if include_self:
                # The initial 0 is below the two +inf sources.
                expected_result[2] = 0
            self.assertEqual(dst, expected_result)
class TestModule(TestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True precision = 1e-5 rel_tol = 1e-5 def _assert_module_parameters_and_buffer_are(self, module, device, dtype): # Check device placement and dtype for created parameters and buffers. # Only verify floating point dtypes since that's what the kwarg or methods # such as `float()` applies to. if not isinstance(device, torch.device): device = torch.device(device) def _check_module(items, name, device=device, dtype=dtype): for item_name, item in items: self.assertEqual( item.device, device, f'{name} {item_name} is on device {item.device} instead of the expected device {device}' ) if item.dtype.is_floating_point: self.assertEqual( item.dtype, dtype, f'{name} {item_name} is of dtype {item.dtype} instead of the expected dtype {dtype}' ) _check_module(module.named_parameters(), "Parameter") _check_module(module.named_buffers(), "Buffer") @skipIfMps # the test doesn't work on MPS as double types are not supported @modules(module_db) def test_forward(self, device, dtype, module_info): module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=False) dtype_to_method_caller = { torch.float32: methodcaller("float"), torch.float64: methodcaller("double"), } for module_input in module_inputs: if module_input.forward_input is None: continue with freeze_rng_state(): # === Instantiate the module. === args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs m = module_cls(*args, **kwargs) m.to(device).to(dtype) # === Do forward pass. === args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs outputs = m(*args, **kwargs) # === Compare outputs to a reference if one is specified. 
=== # TODO: Handle precision reference_fn = module_input.reference_fn if reference_fn is not None: ref_outputs = reference_fn(m, *args, **kwargs) self.assertEqual(outputs, ref_outputs) # === Use the method call and verify the parameters and buffers === if dtype in dtype_to_method_caller: dtype_to_method_caller[dtype](m) m(*args, **kwargs) self._assert_module_parameters_and_buffer_are( m, device, dtype) # Tests passing factory kwargs (e.g. device / dtype) during module instantiation. # They should be applied to any created parameters and buffers. @modules(module_db) def test_factory_kwargs(self, device, dtype, module_info): module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=False) for module_input in module_inputs: args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs # Check if this module creates parameters or registers buffers. # The mock magic here passes through to the real Parameter / register_buffer # logic and is only used to check call inputs. module_creates_params_or_buffers = False parameter_new = mock_wrapper(torch.nn.Parameter.__new__) with patch.object(torch.nn.Parameter, '__new__', parameter_new): register_buffer = mock_wrapper(torch.nn.Module.register_buffer) with patch.object(torch.nn.Module, 'register_buffer', register_buffer): m = module_cls(*args, **kwargs) # Check if a parameter or buffer was created with a tensor not passed to the constructor. constructor_tensors = get_tensors_from(args, kwargs) for mock in [parameter_new.mock, register_buffer.mock]: for call_args, call_kwargs in mock.call_args_list: call_tensors = get_tensors_from( call_args, call_kwargs) if len( call_tensors ) > 0 and not constructor_tensors.intersection( call_tensors): module_creates_params_or_buffers = True break if not module_creates_params_or_buffers: continue # Instantiate module with the factory kwargs. 
kwargs.update({ 'device': device, 'dtype': dtype, }) if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin): # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers. uninit_param_new = mock_wrapper( torch.nn.UninitializedParameter.__new__) with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new): uninit_buffer_new = mock_wrapper( torch.nn.UninitializedBuffer.__new__) with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new): m = module_cls(*args, **kwargs) uninit_param_new.mock.assert_has_calls([ call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls ]) uninit_buffer_new.mock.assert_has_calls([ call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls ]) else: # Check device placement and dtype for created parameters and buffers. # Only verify floating point dtypes since that's what the kwarg applies to. m = module_cls(*args, **kwargs) self._assert_module_parameters_and_buffer_are(m, device, dtype) @onlyCUDA @modules(module_db) def test_multiple_device_transfer(self, device, dtype, module_info): module_cls = module_info.module_cls module_inputs_device = module_info.module_inputs_func( module_info, device=device, dtype=dtype, requires_grad=False) module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype, requires_grad=False) for module_input_device, module_input_cpu in zip( module_inputs_device, module_inputs_cpu): if module_input_device.forward_input is None: continue with freeze_rng_state(): # === Instantiate the module. 
=== args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs m = module_cls(*args, **kwargs) m.to(device).to(dtype) # === Do forward pass on GPU === input_device_args = module_input_device.forward_input.args input_device_kwargs = module_input_device.forward_input.kwargs m(*input_device_args, **input_device_kwargs) self._assert_module_parameters_and_buffer_are(m, device, dtype) # === Move to CPU === input_cpu_args = module_input_cpu.forward_input.args input_cpu_kwargs = module_input_cpu.forward_input.kwargs m.cpu() m(*input_cpu_args, **input_cpu_kwargs) self._assert_module_parameters_and_buffer_are(m, "cpu", dtype) # === Move back to GPU and forward pass === m.cuda() m(*input_device_args, **input_device_kwargs) self._assert_module_parameters_and_buffer_are(m, device, dtype) if torch.cuda.device_count() >= 2: # === test cross-GPU transfer works def _to_device1(objs): if isinstance(objs, (tuple, list)): return type(objs)(_to_device1(item) for item in objs) elif isinstance(objs, dict): return { name: _to_device1(item) for name, item in objs.items() } elif isinstance(objs, torch.Tensor): return objs.cuda(1) else: return objs input_device_1_args = _to_device1(input_device_args) input_device_1_kwargs = _to_device1(input_device_kwargs) m.cuda(1) with torch.cuda.device(1): m(*input_device_1_args, **input_device_1_kwargs) self._assert_module_parameters_and_buffer_are( m, torch.device("cuda:1"), dtype) @modules(module_db) def test_repr(self, device, dtype, module_info): # Test module can be represented with repr and str without errors. 
module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=False) for module_input in module_inputs: args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs m = module_cls(*args, **kwargs) # Check that these methods do not raise errors m.__repr__() str(m) @skipIfMps @modules(module_db) def test_pickle(self, device, dtype, module_info): # Test that module can be pickled and unpickled. module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=False) for module_input in module_inputs: if module_input.forward_input is None: continue args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs with freeze_rng_state(): # === Instantiate the module. === args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs m = module_cls(*args, **kwargs) m.to(device).to(dtype) # === Do forward pass. === args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs output = m(*args, **kwargs) # === Check unpickled module gives the same output. === with tempfile.TemporaryFile() as f: torch.save(m, f) f.seek(0) m_copy = torch.load(f) output_from_copy = m_copy(*args, **kwargs) self.assertEqual(output, output_from_copy) @modules([ module_info for module_info in module_db if 'inplace' in signature(module_info.module_cls).parameters ]) @skipMeta def test_check_inplace(self, device, dtype, module_info): # Check if the inplace variant of the module gives the same result as the out of place # variant. module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=True) for module_input in module_inputs: if module_input.forward_input is None: continue # === Instantiate the module. 
=== args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs m_op = module_cls(*args, **kwargs, inplace=False) m_op.to(device).to(dtype) m_inplace = module_cls(*args, **kwargs, inplace=True) m_inplace.to(device).to(dtype) # === Inplace modules only supports inplace operations on the first argument === input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs # === Do not allow the first input to be in input_kwargs === forward_sig = signature(m_op).parameters self.assertGreaterEqual(len(forward_sig), 1) first_param_name = next(iter(forward_sig.items())) self.assertNotIn(first_param_name, input_kwargs) # === Out of place operation does not write to original tensor === self.assertGreaterEqual(len(input_args), 1) input_version = input_args[0]._version with freeze_rng_state(): output_op = m_op(*input_args, **input_kwargs) self.assertEqual(input_args[0]._version, input_version) # === Check that the inplace operation gives the same result === input_arg_copy = deepcopy(input_args) input_arg_clone = tuple(i.clone() for i in input_arg_copy) with freeze_rng_state(): output_ip = m_inplace(*input_arg_clone, **input_kwargs) self.assertNotEqual(input_arg_clone[0]._version, input_version) self.assertEqual(output_op, output_ip) # === Check that the gradients are the same === grad = output_op.data.clone().normal_() output_op.backward(grad) output_ip.backward(grad) self.assertEqual(input_args[0].grad, input_arg_copy[0].grad) def _traverse_obj(self, obj, func): if isinstance(obj, (tuple, list)): return type(obj)(self._traverse_obj(o, func) for o in obj) elif isgenerator(obj): return tuple(self._traverse_obj(o, func) for o in obj) elif isinstance(obj, dict): return { name: self._traverse_obj(o, func) for name, o in obj.items() } elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)): return func(obj) def _retain_grad(self, obj): # gradients needs to be retained to check for grad. 
This is useful when # non-leafs are present in the graph. def inner_retain_grad(obj): if obj.requires_grad: obj.retain_grad() self._traverse_obj(obj, inner_retain_grad) def _get_grads(self, obj): def inner_get_grad(obj): if obj.requires_grad: return obj.grad return self._traverse_obj(obj, inner_get_grad) def _zero_grad(self, obj): def inner_zero_grad(obj): if obj.grad is not None: obj.grad = None self._traverse_obj(obj, inner_zero_grad) @skipIfMps @modules(module_db) def test_non_contiguous_tensors(self, device, dtype, module_info): # Check modules work with non-contiguous tensors module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=True) def _make_non_contiguous(obj): def inner_make_non_contiguous(obj): # Scalar tensors can not be made non-contiguous if not isinstance(obj, torch.Tensor) or obj.dim() == 0: return obj out = torch.repeat_interleave(obj, 2, dim=-1) out = out[..., ::2].detach() out.requires_grad = obj.requires_grad return out return self._traverse_obj(obj, inner_make_non_contiguous) def _can_be_noncontiguous(obj): if isinstance(obj, (tuple, list)): return any(_can_be_noncontiguous(o) for o in obj) elif isinstance(obj, dict): return any(_can_be_noncontiguous(o) for o in obj.values()) # scalar tensors can not be non-contiguous if not isinstance(obj, torch.Tensor) or obj.dim() == 0: return False return True for module_input in module_inputs: if module_input.forward_input is None: continue input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)): continue # === Instantiate the module. 
=== args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs m = module_cls(*args, **kwargs) m.to(device).to(dtype) self._retain_grad((input_args, input_kwargs)) # === Forward with default input with freeze_rng_state(): default_output = m(*input_args, **input_kwargs) if isinstance(default_output, torch.Tensor): grad_output = default_output.clone().detach_().normal_() default_output.backward(grad_output, retain_graph=True) else: grad_output = tuple( self._traverse_obj( o, lambda o: o.clone().detach_().normal_()) for o in default_output) flattened_default_output, _ = torch.utils._pytree.tree_flatten( default_output) flattened_grad_output, _ = torch.utils._pytree.tree_flatten( grad_output) for o, g_o in zip(flattened_default_output, flattened_grad_output): o.backward(g_o, retain_graph=True) default_input_args_grad, default_input_kwargs_grad = deepcopy( self._get_grads((input_args, input_kwargs))) default_param_grad = deepcopy([p.grad for p in m.parameters()]) # === Construct non-contiguous tensors === nc_input_args, nc_input_kwargs = _make_non_contiguous( (input_args, input_kwargs)) nc_grad_output = _make_non_contiguous(grad_output) # === Compare results with non-contiguous and contiguous tensors === inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)] grads = [grad_output, nc_grad_output] for (in_args, in_kwargs), g_out in product(inputs, grads): g_out_copy = deepcopy(g_out) self._zero_grad((in_args, in_kwargs)) self._zero_grad(m.parameters()) with freeze_rng_state(): out = m(*in_args, **in_kwargs) if isinstance(out, torch.Tensor): out.backward(g_out_copy, retain_graph=True) else: flattened_out, _ = torch.utils._pytree.tree_flatten( out) flattened_g_out_copy, _ = torch.utils._pytree.tree_flatten( g_out_copy) for o, g_o in zip(flattened_out, flattened_g_out_copy): o.backward(g_o, retain_graph=True) input_args_grad, input_kwargs_grad = self._get_grads( (in_args, in_kwargs)) self.assertEqual(out, default_output) 
self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0) self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0) param_grad = [p.grad for p in m.parameters()] self.assertEqual(param_grad, default_param_grad) def _test_gradients_helper(self, device, dtype, module_info, check): # Check gradients module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=True) # === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled gradcheck_nondet_tol = 0.0 if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled): gradcheck_nondet_tol = module_info.gradcheck_nondet_tol for module_input in module_inputs: if module_input.forward_input is None: continue # === Instantiate the module. === args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs m = module_cls(*args, **kwargs) m.to(device).to(dtype) params = tuple(m.parameters()) # === Lazy modules need to see an input to initialize params before gradcheck is run. 
=== input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin): with torch.no_grad(): m(*input_args, **input_kwargs) # === Perform gradient check on the input_args === other_kwargs = {} kwarg_tensors = [] for name, obj in input_kwargs.items(): if isinstance(obj, torch.Tensor): kwarg_tensors.append((name, obj)) else: other_kwargs[name] = obj grad_input = input_args + params + tuple( obj for (_, obj) in kwarg_tensors) flat_input, flat_spec = torch.utils._pytree.tree_flatten( grad_input) def fn_to_gradcheck(*flat_input_and_params): input_and_params = torch.utils._pytree.tree_unflatten( flat_input_and_params, flat_spec) new_input_args = input_and_params[:len(input_args)] kwarg_args = input_and_params[-len(kwarg_tensors):] new_kwargs = { name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args) } with freeze_rng_state(): output = m(*new_input_args, **new_kwargs, **other_kwargs) output_flattened, _ = torch.utils._pytree.tree_flatten( output) return output_flattened self.assertTrue( check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol)) @modules(module_db, allowed_dtypes=[torch.double]) def test_grad(self, device, dtype, module_info): self._test_gradients_helper(device, dtype, module_info, gradcheck) @modules([m for m in module_db if m.supports_gradgrad], allowed_dtypes=[torch.double]) def test_gradgrad(self, device, dtype, module_info): self._test_gradients_helper(device, dtype, module_info, gradgradcheck) @onlyCUDA @toleranceOverride({ torch.float32: tol(5e-2, 0), torch.float64: tol(4e-4, 0) }) @modules(module_db) def test_cpu_gpu_parity(self, device, dtype, module_info): # Test cpu and gpu results are the same module_cls = module_info.module_cls module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype, requires_grad=True) def _to_device(obj): if isinstance(obj, torch.Tensor): res = 
obj.detach().to(device=device) res.requires_grad = obj.requires_grad return res elif isinstance(obj, tuple): return tuple(_to_device(o) for o in obj) elif isinstance(obj, dict): return {key: _to_device(o) for key, o in obj.items()} else: return deepcopy(obj) for module_input in module_inputs_cpu: # === Move input from cpu to device === cpu_forward_args = module_input.forward_input.args cpu_forward_kwargs = module_input.forward_input.kwargs gpu_forward_args, gpu_forward_kwargs = _to_device( (cpu_forward_args, cpu_forward_kwargs)) self._retain_grad((cpu_forward_args, cpu_forward_kwargs, gpu_forward_args, gpu_forward_kwargs)) # === Construct module on cpu and gpu === args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu") gpu_module = module_cls(*args, **kwargs).to(dtype).to(device) for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()): gpu_p.data.copy_(cpu_p) # === Compare forward output between cpu and gpu === cpu_outputs = cpu_module(*cpu_forward_args, **cpu_forward_kwargs) gpu_outputs = gpu_module(*gpu_forward_args, **gpu_forward_kwargs) self.assertEqual(cpu_outputs, gpu_outputs) # === Run backwards on CPU and GPU and compare results === def check_backward(cpu_output, gpu_output): cpu_grad_output = cpu_output.clone().normal_() gpu_grad_output = cpu_grad_output.type_as(gpu_output) cpu_output.backward(cpu_grad_output, retain_graph=True) gpu_output.backward(gpu_grad_output, retain_graph=True) cpu_grad_input = self._get_grads(cpu_forward_args) gpu_grad_input = self._get_grads(gpu_forward_args) self.assertEqual(cpu_grad_input, gpu_grad_input) for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()): self.assertEqual(cpu_p.grad, gpu_p.grad) cpu_grad_kwarg_input = self._get_grads(cpu_forward_kwargs) gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs) self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input) for _ in range(5): if 
isinstance(cpu_outputs, torch.Tensor): check_backward(cpu_outputs, gpu_outputs) else: flatten_cpu_outputs, _ = torch.utils._pytree.tree_flatten( cpu_outputs) flatten_gpu_outputs, _ = torch.utils._pytree.tree_flatten( gpu_outputs) for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs): check_backward(cpu_output, gpu_output) @skipIfMps @modules(module_db) def test_memory_format(self, device, dtype, module_info): module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=False) module_memformat_affects_out = module_info.module_memformat_affects_out def _get_mem_formats(channels_last=False, channels_last_3d=False): if channels_last: return ([torch.contiguous_format, torch.channels_last], [ torch.preserve_format, torch.contiguous_format, torch.channels_last ]) elif channels_last_3d: return ([torch.contiguous_format, torch.channels_last_3d], [ torch.preserve_format, torch.contiguous_format, torch.channels_last_3d ]) else: return ([torch.contiguous_format], [torch.preserve_format, torch.contiguous_format]) # Check that at least one Tensor input has dim == n def _check_dims(obj, n): if isinstance(obj, torch.Tensor): return obj.dim() == n elif isinstance(obj, (tuple, list)): return any(_check_dims(o, n) for o in obj) else: return False # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format def _to_mem_format(mem_format, obj): def inner_to_mem_format(obj): d = obj.dim() if ((mem_format == torch.channels_last and d != 4) or (mem_format == torch.channels_last_3d and d != 5)): return obj return obj.to(memory_format=mem_format) return self._traverse_obj(obj, inner_to_mem_format) def _check_out_mem_format(output, input_mem_format, module_mem_format): def inner_check_out_mem_format(output): d = output.dim() if (d == 4 and ((input_mem_format == torch.channels_last) or (module_mem_format == torch.channels_last and module_memformat_affects_out))): 
self.assertTrue( output.is_contiguous( memory_format=torch.channels_last)) elif (d == 5 and ((input_mem_format == torch.channels_last_3d) or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))): self.assertTrue( output.is_contiguous( memory_format=torch.channels_last_3d)) else: self.assertTrue(output.is_contiguous()) return self._traverse_obj(output, inner_check_out_mem_format) for module_input in module_inputs: if module_input.forward_input is None: continue supports_channels_last = _check_dims( module_input.forward_input.args, 4) supports_channels_last_3d = _check_dims( module_input.forward_input.args, 5) input_mem_formats, module_mem_formats = _get_mem_formats( supports_channels_last, supports_channels_last_3d) with freeze_rng_state(): # === Instantiate the module. === args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs m = module_cls(*args, **kwargs) m.to(device).to(dtype) # === Get output in (contiguous, contiguous) configuration. === args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs desired_outputs = m(*args, **kwargs) for input_mem_format in input_mem_formats: # === Change memformat of input. === module_input.forward_input.args = _to_mem_format( input_mem_format, module_input.forward_input.args) module_input.forward_input.kwargs = _to_mem_format( input_mem_format, module_input.forward_input.kwargs) for module_mem_format in module_mem_formats: # === Change memformat of module === m.to(memory_format=module_mem_format) # === Do forward pass. === args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs outputs = m(*args, **kwargs) # === Compare outputs to (contiguous, contiguous) output. === if input_mem_format != torch.contiguous_format or module_mem_formats != torch.contiguous_format: self.assertEqual(outputs, desired_outputs) # === Check mem format of output. === _check_out_mem_format(outputs, input_mem_format, module_mem_format)
# Database of ModuleInfo entries in alphabetical order. module_db: List[ModuleInfo] = [ ModuleInfo(torch.nn.AvgPool1d, module_inputs_func=module_inputs_torch_nn_AvgPool1d), ModuleInfo(torch.nn.ELU, module_inputs_func=module_inputs_torch_nn_ELU), ModuleInfo(torch.nn.L1Loss, module_inputs_func=module_inputs_torch_nn_L1Loss), ModuleInfo(torch.nn.Linear, module_inputs_func=module_inputs_torch_nn_Linear), ModuleInfo(torch.nn.Bilinear, module_inputs_func=module_inputs_torch_nn_Bilinear, decorators=[ DecorateInfo( toleranceOverride({ torch.float32: tol(atol=1e-4, rtol=1e-4), torch.float64: tol(atol=1e-4, rtol=1e-4)}), 'TestModule', 'test_forward', device_type='cpu') ]), ModuleInfo(torch.nn.NLLLoss, module_inputs_func=module_inputs_torch_nn_NLLLoss), ModuleInfo(torch.nn.GaussianNLLLoss, module_inputs_func=module_inputs_torch_nn_GaussianNLLLoss), ModuleInfo(torch.nn.Hardswish, module_inputs_func=module_inputs_torch_nn_Hardswish, supports_gradgrad=False), ModuleInfo(torch.nn.TransformerEncoderLayer, module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer), ModuleInfo(torch.nn.TransformerDecoderLayer, module_inputs_func=module_inputs_torch_nn_TransformerDecoderLayer), ModuleInfo(torch.nn.Transformer,
backward_dtypes=floating_types(), sample_inputs_func=sample_inputs_i0_i1, supports_forward_ad=True, supports_fwgrad_bwgrad=True, ), UnaryUfuncInfo( "special.i1", aten_name="special_i1", ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1) if TEST_SCIPY else None, dtypes=all_types_and(torch.bool), dtypesIfCUDA=all_types_and(torch.bool), sample_inputs_func=sample_inputs_i0_i1, decorators=(DecorateInfo( toleranceOverride({ torch.float32: tol(atol=1e-4, rtol=0), torch.bool: tol(atol=1e-4, rtol=0), })), ), skips=(DecorateInfo( unittest.skip("Incorrect result!"), "TestUnaryUfuncs", "test_reference_numerics_large", dtypes=(torch.int8, ), ), ), supports_fwgrad_bwgrad=True, supports_forward_ad=True, ), UnaryUfuncInfo( "special.i1e", aten_name="special_i1e", ref=scipy.special.i1e if TEST_SCIPY else None,
# Database of ModuleInfo entries in alphabetical order. module_db: List[ModuleInfo] = [ ModuleInfo(torch.nn.AvgPool1d, module_inputs_func=module_inputs_torch_nn_AvgPool1d), ModuleInfo(torch.nn.ELU, module_inputs_func=module_inputs_torch_nn_ELU), ModuleInfo(torch.nn.L1Loss, module_inputs_func=module_inputs_torch_nn_L1Loss), ModuleInfo(torch.nn.Linear, module_inputs_func=module_inputs_torch_nn_Linear), ModuleInfo(torch.nn.Bilinear, module_inputs_func=module_inputs_torch_nn_Bilinear, decorators=[ DecorateInfo(toleranceOverride({ torch.float32: tol(atol=1e-4, rtol=1e-4), torch.float64: tol(atol=1e-4, rtol=1e-4) }), 'TestModule', 'test_forward', device_type='cpu') ]), ModuleInfo(torch.nn.NLLLoss, module_inputs_func=module_inputs_torch_nn_NLLLoss), ModuleInfo(torch.nn.Hardswish, module_inputs_func=module_inputs_torch_nn_Hardswish, supports_gradgrad=False), ModuleInfo( torch.nn.TransformerEncoderLayer, module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
"TestNormalizeOperators", "test_normalize_operator_exhaustive", ), # FIXME: sum reduces all dimensions when dim=[] DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), DecorateInfo( unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" ), # RuntimeError: undefined value tensor DecorateInfo( unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" ), ), decorators=[ DecorateInfo( toleranceOverride({torch.bfloat16: tol(atol=1e-03, rtol=1e-03)}), "TestReductions", "test_reference_masked", ), DecorateInfo( toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}), "TestReductions", "test_reference_masked", ), DecorateInfo( toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), "TestReductions", "test_ref_small_input", ), ], sample_inputs_func=sample_inputs_masked_reduction,
class TestScatterGather(TestCase): # Fills an index tensor with valid indices def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o): for i in range(1 if dim == 0 else m): for j in range(1 if dim == 1 else n): for k in range(1 if dim == 2 else o): ii = [i, j, k] ii[dim] = slice(0, idx.size(dim) + 1) idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row] @dtypes(torch.float32, torch.complex64) def test_gather(self, device, dtype): m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint( 10, 20) elems_per_row = random.randint(1, 10) dim = random.randrange(3) src = make_tensor((m, n, o), device=device, dtype=dtype) idx_size = [m, n, o] idx_size[dim] = elems_per_row idx = make_tensor(idx_size, device=device, dtype=torch.long) self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o) actual = torch.gather(src, dim, idx) expected = torch.zeros(idx_size, device=device, dtype=dtype) for i in range(idx_size[0]): for j in range(idx_size[1]): for k in range(idx_size[2]): ii = [i, j, k] ii[dim] = idx[i, j, k] expected[i, j, k] = src[tuple(ii)] self.assertEqual(actual, expected, atol=0, rtol=0) # Guarded because torch.max isn't defined for complex types if not dtype.is_complex: src = make_tensor((3, 4, 5), device=device, dtype=dtype) expected, idx = src.max(2, True) actual = torch.gather(src, 2, idx) self.assertEqual(actual, expected, atol=0, rtol=0) @dtypes(torch.bool) def test_gather_bool(self, device, dtype): src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype) idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) actual = torch.gather(src, 1, idx) expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype) self.assertEqual(actual, expected, atol=0, rtol=0) def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction): m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint( 10, 20) elems_per_row = random.randint(1, 10) dim = 
random.randrange(3) idx_size = [m, n, o] idx_size[dim] = elems_per_row idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long) self._fill_indices(idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o) if is_scalar: src = random.random() else: src_size = [random.randint(1, 5) + s for s in idx_size] src = make_tensor(tuple(src_size), device=device, dtype=dtype) base = make_tensor((m, n, o), device=device, dtype=dtype) if reduction is not None: actual = fn(base.clone(), dim, idx, src, reduce=reduction) else: actual = fn(base.clone(), dim, idx, src) expected = base.clone() for i in range(idx_size[0]): for j in range(idx_size[1]): for k in range(idx_size[2]): ii = [i, j, k] ii[dim] = idx[i, j, k] if fn is torch.Tensor.scatter_add_: expected[tuple(ii)] += src[i, j, k] else: # method may be 'scatter_' or 'scatter' # both might have a reduction argument value = src if is_scalar else src[i, j, k] if reduction == "add": expected[tuple(ii)] += value elif reduction == "multiply": expected[tuple(ii)] *= value else: expected[tuple(ii)] = value self.assertEqual(actual, expected, atol=0, rtol=0) # Tests empty index dst = make_tensor((2, 2), device=device, dtype=dtype) idx = torch.tensor((), device=device, dtype=torch.long) src = make_tensor((2, 2), device=device, dtype=dtype) if reduction is not None: actual = fn(dst, 0, idx, src, reduce=reduction) else: actual = fn(dst, 0, idx, src) self.assertEqual(actual, dst, atol=0, rtol=0) @dtypes(torch.float16, torch.float32, torch.complex64) def test_scatter_(self, device, dtype): self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype, is_scalar=False, reduction=None) @dtypes(torch.float16, torch.float32, torch.complex64) def test_scatter__scalar(self, device, dtype): self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype, is_scalar=True, reduction=None) # FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat' 
@toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)}) @dtypesIfCUDA(torch.float16, torch.float32) @dtypes(torch.float16, torch.float32, torch.complex64) def test_scatter__reductions(self, device, dtype): for reduction in ("add", "multiply"): self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype, is_scalar=False, reduction=reduction) self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype, is_scalar=True, reduction=reduction) @dtypes(torch.float16, torch.float32, torch.complex64) def test_scatter_add_(self, device, dtype): self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype, is_scalar=False, reduction=None) @dtypes(torch.float32) def test_scatter_add_mult_index_base(self, device, dtype): m, n = 30, 40 idx = torch.zeros(m, n, device=device, dtype=torch.long) src = torch.ones(m, n, device=device, dtype=dtype) res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src) res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src) self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0) self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)