def test_batch(self, batch_size, device, dtype):
    """A batch of random (center, angle) pairs yields one 3x4 matrix per sample."""
    opts = {"device": device, "dtype": dtype}
    center = torch.rand(batch_size, 3, **opts)
    angle = torch.rand(batch_size, 3, **opts)
    # Unit scale on every axis; ones_like inherits angle's device and dtype.
    scales = torch.ones_like(angle)
    projection = proj.get_projective_transform(center, angle, scales)
    assert projection.shape == (batch_size, 3, 4)
def test_rot90z(self, device, dtype):
    """A 90-degree third angle component about the origin yields the z-axis rotation matrix."""
    opts = {"device": device, "dtype": dtype}
    origin = torch.zeros(1, 3, **opts)
    angle = torch.tensor([[0.0, 0.0, 90.0]], **opts)
    scales = torch.ones_like(angle)
    actual = proj.get_projective_transform(origin, angle, scales)
    # x -> y, y -> -x, z unchanged; zero translation because the center is the origin.
    rows = [
        [0.0, -1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
    ]
    expected = torch.tensor(rows, **opts)[None]
    assert_close(actual, expected, atol=1e-4, rtol=1e-4)
def test_rotate_y_large(self, device, dtype):
    """Rotates 90deg anti-clockwise.

    Warps a two-channel 3x3x3 volume with angles (0, 90, 0) about the
    center of the voxel grid and compares every voxel of the result
    against a hand-computed expectation.
    """
    # Named `volume` rather than `input` to avoid shadowing the builtin.
    volume = torch.tensor(
        [
            [
                [
                    [[0.0, 4.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 2.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                ],
                [
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 9.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 6.0, 7.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 0.0]],
                ],
            ]
        ],
        device=device,
        dtype=dtype,
    )
    expected = torch.tensor(
        [
            [
                [
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                    [[4.0, 2.0, 0.0], [3.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                ],
                [
                    [[0.0, 0.0, 0.0], [0.0, 7.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 6.0, 8.0], [9.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                ],
            ]
        ],
        device=device,
        dtype=dtype,
    )
    _, _, D, H, W = volume.shape
    # Rotate about the middle voxel; coordinates listed width, height, depth.
    center = torch.tensor([[(W - 1) / 2, (H - 1) / 2, (D - 1) / 2]], device=device, dtype=dtype)
    angles = torch.tensor([[0.0, 90.0, 0.0]], device=device, dtype=dtype)
    # ones_like already inherits angles' device and dtype.
    scales = torch.ones_like(angles)
    P = proj.get_projective_transform(center, angles, scales)
    output = proj.warp_affine3d(volume, P, (3, 3, 3))
    assert_close(output, expected, rtol=1e-4, atol=1e-4)
def test_rotate_y(self, device, dtype):
    """Warp a single-channel 3x3x3 volume with angles (0, 90, 0) about its center.

    The two non-zero voxels in the centre column (depths 0 and 1) are
    expected to land in the middle depth slice, side by side along width.
    """
    # Named `volume` rather than `input` to avoid shadowing the builtin.
    volume = torch.tensor(
        [
            [
                [
                    [[0.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                ]
            ]
        ],
        device=device,
        dtype=dtype,
    )
    expected = torch.tensor(
        [
            [
                [
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [2.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                ]
            ]
        ],
        device=device,
        dtype=dtype,
    )
    _, _, D, H, W = volume.shape
    # Rotate about the middle voxel; coordinates listed width, height, depth.
    center = torch.tensor([[(W - 1) / 2, (H - 1) / 2, (D - 1) / 2]], device=device, dtype=dtype)
    angles = torch.tensor([[0.0, 90.0, 0.0]], device=device, dtype=dtype)
    # ones_like already inherits angles' device and dtype.
    scales = torch.ones_like(angles)
    P = proj.get_projective_transform(center, angles, scales)
    output = proj.warp_affine3d(volume, P, (3, 3, 3))
    assert_close(output, expected, rtol=1e-4, atol=1e-4)
def test_smoke(self, device, dtype):
    """A single random (center, angle) pair produces a (1, 3, 4) projection."""
    opts = {"device": device, "dtype": dtype}
    center = torch.rand(1, 3, **opts)
    angle = torch.rand(1, 3, **opts)
    # Unit scale on every axis; ones_like inherits angle's device and dtype.
    scales = torch.ones_like(angle)
    projection = proj.get_projective_transform(center, angle, scales)
    assert projection.shape == (1, 3, 4)