def test_even(self, device):
    """Check ``_get_center_kernel3d`` for an even-sized (2, 4, 3) window.

    The expected kernel has shape (3, 3, 2, 4, 3). Only the diagonal channel
    pairs (c, c) are non-zero. Within each, the 2 (depth) x 2 (height, rows
    1:3) x 1 (width, col 1) = 4 central cells each hold 0.25.
    """
    # NOTE(review): a `test_even(self, device, dtype)` variant also appears in
    # this file; if both are defined in the same class, the later definition
    # shadows this one and this test never runs — confirm.
    kernel = _get_center_kernel3d(2, 4, 3).to(device)
    # Allocate directly on the target device instead of building on the host
    # and copying with `.to(device)`.
    expected = torch.zeros(3, 3, 2, 4, 3, device=device)
    for channel in range(3):
        expected[channel, channel, :, 1:3, 1] = 0.25
    assert_allclose(kernel, expected)
def test_odd(self, device):
    """Check ``_get_center_kernel3d`` for an odd-sized (3, 5, 7) window.

    The expected kernel has shape (3, 3, 3, 5, 7). Only the diagonal channel
    pairs (c, c) are non-zero, each with a single 1 at the exact center
    index (1, 2, 3).
    """
    # NOTE(review): a `test_odd(self, device, dtype)` variant also appears in
    # this file; if both are defined in the same class, the later definition
    # shadows this one and this test never runs — confirm.
    kernel = _get_center_kernel3d(3, 5, 7).to(device)
    # Allocate directly on the target device instead of building on the host
    # and copying with `.to(device)`.
    expected = torch.zeros(3, 3, 3, 5, 7, device=device)
    for channel in range(3):
        expected[channel, channel, 1, 2, 3] = 1.
    assert_allclose(kernel, expected)
def test_even(self, device, dtype):
    """Verify the even-window (2, 4, 3) center kernel on a given device/dtype.

    The kernel is built on ``device``, cast to ``dtype`` and compared with a
    (3, 3, 2, 4, 3) reference. In the reference, each diagonal channel pair
    holds 0.25 over all depth slices, rows 1:3 and column 1.
    """
    tol = 1e-4
    actual = _get_center_kernel3d(2, 4, 3, device=device).to(dtype=dtype)
    target = torch.zeros(3, 3, 2, 4, 3, device=device, dtype=dtype)
    for ch in range(3):
        target[ch, ch, :, 1:3, 1] = 0.25
    assert_allclose(actual, target, atol=tol, rtol=tol)
def test_odd(self, device, dtype):
    """Verify the odd-window (3, 5, 7) center kernel on a given device/dtype.

    The kernel is built on ``device``, cast to ``dtype`` and compared with a
    (3, 3, 3, 5, 7) reference. In the reference, each diagonal channel pair
    has a single 1 at the center index (1, 2, 3).
    """
    tol = 1e-4
    actual = _get_center_kernel3d(3, 5, 7, device=device).to(dtype=dtype)
    target = torch.zeros(3, 3, 3, 5, 7, device=device, dtype=dtype)
    for ch in range(3):
        target[ch, ch, 1, 2, 3] = 1.
    assert_allclose(actual, target, atol=tol, rtol=tol)
def test_smoke(self, device):
    """Smoke-test ``_get_center_kernel3d``: the output shape is (3, 3, D, H, W)."""
    depth, height, width = 6, 3, 4
    center_kernel = _get_center_kernel3d(depth, height, width).to(device)
    assert center_kernel.shape == (3, 3, depth, height, width)