def test_jit(self, device, dtype):
    """Check that a TorchScript-compiled ``RandomCrop.forward`` yields a 3x3 crop.

    The crop offset is random, but the input image is constant (all ones), so every
    possible 3x3 crop is identical. This lets the result be compared against a fixed
    tensor instead of a (randomly-offset) eager reference.
    """
    # Define script
    op = RandomCrop(size=(3, 3), p=1.).forward
    op_script = torch.jit.script(op)
    img = torch.ones(1, 1, 5, 6, device=device, dtype=dtype)

    actual = op_script(img)
    # The previous reference, ``kornia.center_crop3d(img)``, omitted the crop size and
    # applied a volumetric (5-D) crop to a 4-D image, so it could not produce a valid
    # expectation. A constant input makes the expected 3x3 crop deterministic.
    expected = torch.ones(1, 1, 3, 3, device=device, dtype=dtype)
    assert_allclose(actual, expected)
def test_jit(self, device, dtype):
    """Ensure ``kornia.center_crop3d`` gives the same result scripted and eager."""
    crop_size = (4, 3, 2)
    # TorchScript-compile the functional crop directly.
    scripted_crop = torch.jit.script(kornia.center_crop3d)
    volume = torch.ones(4, 3, 5, 6, 7, device=device, dtype=dtype)

    # Scripted call is evaluated first, then the eager reference.
    assert_allclose(scripted_crop(volume, crop_size), kornia.center_crop3d(volume, crop_size))
def test_center_crop_357(self, crop_size, device, dtype):
    """Compare ``center_crop3d`` on a 7x7x7 ramp volume against manual slicing."""
    inp = torch.arange(0.0, 343.0, device=device, dtype=dtype).view(1, 1, 7, 7, 7)

    def _centered(axis, extent):
        # Slice centred on the middle voxel of ``axis``, spanning 2 * (extent // 2) + 1 voxels.
        middle = inp.size(axis) // 2
        half = extent // 2
        return slice(middle - half, middle + half + 1)

    depth_sl, height_sl, width_sl = (_centered(axis, crop_size[axis - 2]) for axis in (2, 3, 4))
    expected = inp[:, :, depth_sl, height_sl, width_sl]

    out_crop = kornia.center_crop3d(inp, crop_size, align_corners=True)
    assert_allclose(out_crop, expected, rtol=1e-4, atol=1e-4)
def test_center_crop_357_batch(self, crop_size, device, dtype):
    """Batched variant: two distinct 7x7x7 ramp volumes cropped and checked against slicing."""
    # Sample 0 holds 0..342, sample 1 holds 343..685.
    samples = [
        torch.arange(start, start + 343., device=device, dtype=dtype).view(1, 1, 7, 7, 7)
        for start in (0., 343.)
    ]
    inp = torch.cat(samples)

    def _centered(axis, extent):
        # Slice centred on the middle voxel of ``axis``, spanning 2 * (extent // 2) + 1 voxels.
        middle = inp.size(axis) // 2
        half = extent // 2
        return slice(middle - half, middle + half + 1)

    depth_sl, height_sl, width_sl = (_centered(axis, crop_size[axis - 2]) for axis in (2, 3, 4))
    expected = inp[:, :, depth_sl, height_sl, width_sl]

    out_crop = kornia.center_crop3d(inp, crop_size, align_corners=True)
    assert_allclose(out_crop, expected)