def test_exception(self, batch_size, device, dtype):
    """``kornia.warp_points_tps`` must reject malformed arguments.

    Each of the four positional arguments is checked in turn:
    - a NumPy array in place of a tensor must raise ``TypeError``;
    - a tensor of the wrong rank must raise ``ValueError``.

    Args:
        batch_size: leading batch dimension of every tensor built here.
        device: device the tensors are allocated on (presumably a pytest
            fixture / parametrization — confirm in conftest).
        dtype: floating dtype of the tensors (presumably a pytest fixture).
    """
    # Reference inputs with the shapes this test treats as valid:
    # (B, 5, 2) for the points and kernel weights, (B, 3, 2) for the affine part.
    src = torch.rand(batch_size, 5, 2, device=device, dtype=dtype)
    kernel = torch.zeros_like(src)
    affine = torch.zeros(batch_size, 3, 2, device=device, dtype=dtype)

    # Build every bad argument BEFORE entering ``pytest.raises``: an error raised
    # while constructing it (``.numpy()`` on a CUDA tensor raises TypeError, for
    # example) must not be mistaken for the error expected from warp_points_tps.
    src_np = src.cpu().numpy()
    kernel_np = kernel.cpu().numpy()
    affine_np = affine.cpu().numpy()
    # 2-D tensors where 3-D ones are expected.
    points_bad = torch.rand(batch_size, 5, device=device, dtype=dtype)
    affine_bad = torch.rand(batch_size, 3, device=device, dtype=dtype)

    # Call directly (no ``assert``): if nothing is raised, pytest reports
    # "DID NOT RAISE" instead of failing on the truthiness of a result tensor.
    with pytest.raises(TypeError):
        kornia.warp_points_tps(src_np, src, kernel, affine)
    with pytest.raises(TypeError):
        kornia.warp_points_tps(src, src_np, kernel, affine)
    with pytest.raises(TypeError):
        kornia.warp_points_tps(src, src, kernel_np, affine)
    with pytest.raises(TypeError):
        kornia.warp_points_tps(src, src, kernel, affine_np)

    with pytest.raises(ValueError):
        kornia.warp_points_tps(points_bad, src, kernel, affine)
    with pytest.raises(ValueError):
        kornia.warp_points_tps(src, points_bad, kernel, affine)
    with pytest.raises(ValueError):
        kornia.warp_points_tps(src, src, points_bad, affine)
    with pytest.raises(ValueError):
        kornia.warp_points_tps(src, src, kernel, affine_bad)
def test_warp(self, batch_size, device, dtype):
    """Fitting a TPS on (src, dst) and warping src must reproduce dst.

    The coefficients come from ``kornia.get_tps_transform`` and are fed to
    ``kornia.warp_points_tps`` with ``dst`` as the second argument. The result
    is compared to ``dst`` with absolute and relative tolerance 1e-4.
    """
    # NOTE(review): ``dtype`` is not used; ``_sample_points`` is only given the
    # device here — confirm whether it should also receive the dtype.
    points_src, points_dst = _sample_points(batch_size, device)

    tps_kernel, tps_affine = kornia.get_tps_transform(points_src, points_dst)
    warped = kornia.warp_points_tps(points_src, points_dst, tps_kernel, tps_affine)

    tolerance = 1e-4
    assert_allclose(warped, points_dst, atol=tolerance, rtol=tolerance)
def test_smoke(self, batch_size, device, dtype):
    """Smoke test: warping the source points preserves their shape."""
    # NOTE(review): ``dtype`` is not used here either — confirm intent.
    src_pts, dst_pts = _sample_points(batch_size, device)

    weights_kernel, weights_affine = kornia.get_tps_transform(src_pts, dst_pts)
    result = kornia.warp_points_tps(src_pts, dst_pts, weights_kernel, weights_affine)

    expected_shape = src_pts.shape
    assert result.shape == expected_shape