Ejemplo n.º 1
0
 def test_jit(self, device, dtype):
     """Check that a TorchScript-compiled VonMisesKernel matches eager mode.

     Two identically configured kernels are built: one is run eagerly and the
     other is passed through ``torch.jit.script``. Both are cast to the input's
     device/dtype and put in eval mode before comparison.
     """
     batch, channels, height, width = 2, 1, 13, 13
     inputs = torch.rand(batch, channels, height, width, device=device, dtype=dtype)

     def build_kernel():
         # A fresh module per call, so the scripted copy is independent of the
         # eager one.
         kernel = VonMisesKernel(patch_size=13, coeffs=[0.38214156, 0.48090413])
         return kernel.to(inputs.device, inputs.dtype).eval()

     eager = build_kernel()
     scripted = torch.jit.script(build_kernel())
     assert_close(eager(inputs), scripted(inputs))
Ejemplo n.º 2
0
    def test_toy(self, device):
        """Check VonMisesKernel responses on a 6x6 patch with a zeroed right half.

        The single input channel is ones in columns 0-2 and zeros in columns 3-5.
        Each expected output map below is one 6-element row repeated down all six
        rows. Comparisons use atol=rtol=1e-3.
        """
        patch = torch.ones(1, 1, 6, 6, dtype=torch.float32, device=device)
        patch[0, 0, :, 3:] = 0  # zero the right half (columns 3-5)
        vm = VonMisesKernel(patch_size=6, coeffs=[0.38214156,
                                                  0.48090413]).to(device)
        out = vm(patch)

        def rows_of(values):
            # Build expected tensors with out's dtype/device instead of the
            # legacy torch.Tensor constructor, whose dtype follows the global
            # default. This keeps the comparison like-for-like.
            row = torch.tensor(values, dtype=out.dtype, device=out.device)
            return row.unsqueeze(0).repeat(6, 1)

        # Channel 0: constant 0.6182 everywhere.
        assert_close(out[0, 0], torch.full_like(out[0, 0], 0.6182),
                     atol=1e-3, rtol=1e-3)

        # Channel 1: 0.3747 in columns 0-2, 0.6935 in columns 3-5.
        assert_close(out[0, 1],
                     rows_of([0.3747, 0.3747, 0.3747, 0.6935, 0.6935, 0.6935]),
                     atol=1e-3, rtol=1e-3)

        # Channel 2: 0.5835 in columns 0-2, 0 in columns 3-5.
        assert_close(out[0, 2],
                     rows_of([0.5835, 0.5835, 0.5835, 0.0000, 0.0000, 0.0000]),
                     atol=1e-3, rtol=1e-3)
Ejemplo n.º 3
0
 def test_shape(self, ps, device):
     """A (1, 1, ps, ps) patch yields a (1, 3, ps, ps) response for two coeffs."""
     patch = torch.ones(1, 1, ps, ps, device=device)
     kernel = VonMisesKernel(patch_size=ps, coeffs=[0.38214156, 0.48090413])
     response = kernel.to(device)(patch)
     expected_shape = (1, 3, ps, ps)
     assert response.shape == expected_shape
Ejemplo n.º 4
0
 def vm_describe(patches, ps=13):
     """Run a double-precision VonMisesKernel of size ``ps`` over ``patches``.

     ``device`` is a free variable here. It must come from an enclosing scope
     that is not visible in this snippet (presumably the surrounding test's
     ``device`` argument — confirm against the caller).
     """
     kernel = VonMisesKernel(patch_size=ps, coeffs=[0.38214156, 0.48090413])
     # nn.Module.to moves the module in place and returns it, so chaining is
     # equivalent to calling .to(device) as a separate statement.
     kernel = kernel.double().to(device)
     return kernel(patches.double())
Ejemplo n.º 5
0
 def test_print(self, device):
     """Smoke-test the string representation of a VonMisesKernel.

     ``repr()`` must succeed and produce a non-empty string.
     """
     vm = VonMisesKernel(patch_size=32, coeffs=[0.38214156,
                                                0.48090413]).to(device)
     text = repr(vm)
     assert isinstance(text, str) and len(text) > 0
Ejemplo n.º 6
0
 def test_coeffs(self, coeffs, device):
     """A 15x15 patch yields ``2 * len(coeffs) - 1`` output channels."""
     size = 15
     kernel = VonMisesKernel(patch_size=size, coeffs=coeffs).to(device)
     response = kernel(torch.ones(1, 1, size, size, device=device))
     n_channels = 2 * len(coeffs) - 1
     assert response.shape == (1, n_channels, size, size)
Ejemplo n.º 7
0
 def test_batch_shape(self, bs, device):
     """The batch dimension ``bs`` is preserved: (bs, 1, 15, 15) -> (bs, 3, 15, 15)."""
     size = 15
     kernel = VonMisesKernel(patch_size=size, coeffs=[0.38214156, 0.48090413])
     kernel = kernel.to(device)
     response = kernel(torch.ones(bs, 1, size, size, device=device))
     assert response.shape == (bs, 3, size, size)