def test_non_contiguous(self):
    """Exercise k2.index_add on strided (non-contiguous) 1-D float32 tensors.

    The result is checked against torch.Tensor.index_add_: a one-slot
    scratch cell is prepended and all indexes shifted by one, then the
    shifted reference result is compared element-wise with k2's output.
    """
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        n = torch.randint(100, 10000, (1, )).item()
        src = torch.rand(n, dtype=torch.float32).to(device)

        # Slice with a random step so that `src` is non-contiguous.
        step = torch.randint(2, 8, (1, )).item()
        src = src[::step]
        n = src.numel()

        num_indexes = n * torch.randint(2, 10, (1, )).item()
        index = torch.randint(0, n, (num_indexes, ),
                              dtype=torch.int32).to(device)

        # Build `value` the same way so it is non-contiguous too.
        vstep = torch.randint(2, 6, (1, )).item()
        value = torch.rand(num_indexes * vstep,
                           dtype=torch.float32).to(device)
        value = value[::vstep]

        assert not src.is_contiguous()
        assert index.is_contiguous()
        assert not value.is_contiguous()

        expected = src.clone()
        k2.index_add(index, value, src)

        # Reference computation with a scratch slot at position 0.
        expected = torch.cat([torch.tensor([0]).to(expected), expected])
        expected.index_add_(0, index.to(torch.int64) + 1, value)
        assert torch.allclose(src, expected[1:])
def test_2d_non_contiguous(self):
    """Exercise k2.index_add on 2-D tensors that are non-contiguous in dim 1.

    Covers int32/float32/float64 and both CPU and (if available) CUDA.
    Index value -1 means "skip this row"; the reference computation
    prepends a scratch row and shifts all indexes by one so that -1
    accumulates harmlessly into the scratch row.
    """
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        for dtype in [torch.int32, torch.float32, torch.float64]:
            col_stride = torch.randint(2, 8, (1, )).item()
            num_rows = torch.randint(10, 1000, (1, )).item()
            num_cols = torch.randint(10, 1000, (1, )).item() * col_stride

            src = torch.randint(-1000, 1000,
                                size=(num_rows, num_cols),
                                dtype=dtype,
                                device=device)
            # Column slicing makes `src` non-contiguous.
            src = src[:, ::col_stride]

            num_indexes = num_rows * torch.randint(2, 10, (1, )).item()
            index = torch.randint(-1, num_rows,
                                  size=(num_indexes, ),
                                  dtype=torch.int32,
                                  device=device)

            # `value` gets the same number of columns as the sliced
            # `src` once its own stride is applied.
            value_stride = torch.randint(2, 8, (1, )).item()
            value = torch.randint(-1000, 1000,
                                  size=(num_indexes,
                                        num_cols * value_stride),
                                  dtype=dtype,
                                  device=device)
            value = value[:, ::(col_stride * value_stride)]

            assert not src.is_contiguous()
            assert index.is_contiguous()
            assert not value.is_contiguous()
            assert src.dtype == value.dtype == dtype
            assert index.dtype == torch.int32
            assert src.device == value.device == index.device == device

            expected = src.clone()
            k2.index_add(index, value, src)

            # Reference: scratch row at index 0, indexes shifted by 1.
            expected = torch.cat(
                [torch.zeros(1, expected.shape[1]).to(expected), expected])
            expected.index_add_(0, index.to(torch.int64) + 1, value)
            assert torch.all(torch.eq(src, expected[1:]))
def test_1d_non_contiguous(self):
    """Exercise k2.index_add on strided 1-D tensors of several dtypes.

    Covers int32/float32/float64 on CPU and (if available) CUDA, with
    index value -1 meaning "skip".  Verified against
    torch.Tensor.index_add_ via the prepend-scratch-slot trick.
    """
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        for dtype in [torch.int32, torch.float32, torch.float64]:
            n = torch.randint(20, 1000, (1, )).item()
            # Keep the stride small relative to n so the slice is
            # non-empty; n >= 20 guarantees the range is valid.
            step = torch.randint(2, n // 10 + 1, size=(1, )).item()
            src = torch.randint(-1000, 1000,
                                size=(n, ),
                                dtype=dtype,
                                device=device)
            src = src[::step]

            num_indexes = src.numel() * torch.randint(2, 10, (1, )).item()
            index = torch.randint(-1, src.numel(),
                                  size=(num_indexes, ),
                                  dtype=torch.int32,
                                  device=device)

            vstep = torch.randint(2, 6, (1, )).item()
            value = torch.randint(-1000, 1000,
                                  size=(num_indexes * vstep, ),
                                  dtype=dtype,
                                  device=device)
            value = value[::vstep]

            assert not src.is_contiguous()
            assert index.is_contiguous()
            assert not value.is_contiguous()
            assert src.dtype == value.dtype == dtype
            assert index.dtype == torch.int32
            assert src.device == value.device == index.device == device

            expected = src.clone()
            k2.index_add(index, value, src)

            # Reference with a scratch slot absorbing index -1.
            expected = torch.cat([torch.tensor([0]).to(expected), expected])
            expected.index_add_(0, index.to(torch.int64) + 1, value)
            assert torch.all(torch.eq(src, expected[1:]))
def test_contiguous(self):
    """Exercise k2.index_add on contiguous 1-D float32 tensors.

    Verified against torch.Tensor.index_add_ by prepending a scratch
    element and shifting indexes by one, so that index value -1
    accumulates into the scratch slot and is effectively ignored.
    """
    # BUG FIX: the original built torch.device('cuda', 0) and iterated
    # over it unconditionally, which crashes on CPU-only machines.
    # Guard on availability, consistent with the sibling tests.
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        num_elements = torch.randint(10, 1000, (1, )).item()
        src = torch.rand(num_elements, dtype=torch.float32).to(device)

        num_indexes = num_elements * torch.randint(2, 10, (1, )).item()
        index = torch.randint(-1, num_elements, (num_indexes, ),
                              dtype=torch.int32).to(device)
        value = torch.rand(num_indexes, dtype=torch.float32).to(device)

        saved = src.clone()
        k2.index_add(index, value, src)

        # Reference computation: scratch slot at position 0 absorbs
        # contributions whose index is -1.
        saved = torch.cat([torch.tensor([0]).to(saved), saved])
        saved.index_add_(0, index.to(torch.int64) + 1, value)
        assert torch.allclose(src, saved[1:])
def test_2d(self):
    """Exercise k2.index_add on contiguous 2-D tensors.

    Covers int32/float32/float64 on every device in ``self.devices``.
    Index value -1 means "skip this row"; verification prepends a
    scratch zero row and shifts indexes by one before comparing with
    torch.Tensor.index_add_.
    """
    for device in self.devices:
        for dtype in [torch.int32, torch.float32, torch.float64]:
            num_rows = torch.randint(10, 1000, (1,)).item()
            num_cols = torch.randint(10, 1000, (1,)).item()
            src = torch.randint(-1000, 1000,
                                size=(num_rows, num_cols),
                                dtype=dtype,
                                device=device)

            num_indexes = num_rows * torch.randint(2, 10, (1,)).item()
            index = torch.randint(-1, num_rows,
                                  size=(num_indexes,),
                                  dtype=torch.int32,
                                  device=device)
            value = torch.randint(-1000, 1000,
                                  size=(num_indexes, num_cols),
                                  dtype=dtype,
                                  device=device)

            # Everything is freshly allocated, hence contiguous.
            assert src.is_contiguous()
            assert index.is_contiguous()
            assert value.is_contiguous()
            assert src.dtype == value.dtype == dtype
            assert index.dtype == torch.int32
            assert src.device == value.device == index.device == device

            expected = src.clone()
            k2.index_add(index, value, src)

            # Reference: scratch row at index 0 absorbs index -1.
            expected = torch.cat(
                [torch.zeros(1, expected.shape[1]).to(expected), expected])
            expected.index_add_(0, index.to(torch.int64) + 1, value)
            assert torch.all(torch.eq(src, expected[1:]))
def my_func(index: torch.Tensor, value: torch.Tensor,
            src: torch.Tensor) -> torch.Tensor:
    """Return ``src`` plus the scatter-add of ``value`` at ``index``.

    A fresh float32 zero buffer shaped like ``src`` is filled via
    k2.index_add, then added to ``src``; ``src`` itself is not modified.
    """
    acc = torch.zeros_like(src).to(torch.float32)
    k2.index_add(index, value, acc)
    return src + acc