def test_jit(self, device, dtype):
     """Eager and TorchScript-compiled estimators must produce the same LAFs.

     The scripted module is compiled from the very same eval-mode instance
     used for the eager call, so both share parameters and any mismatch can
     only come from scripting itself.
     """
     B, C, H, W = 1, 1, 32, 32
     patches = torch.rand(B, C, H, W, device=device, dtype=dtype)
     # Build the LAF in the test dtype as well, so LAF, patches and model all
     # agree and no implicit dtype mixing is exercised by this test.
     laf = torch.tensor([[[[8.0, 0.0, 16.0], [0.0, 8.0, 16.0]]]],
                        device=device,
                        dtype=dtype)
     laf_estimator = LAFAffNetShapeEstimator(True).to(
         device, dtype=dtype).eval()
     # Script the same instance rather than constructing (and loading) a
     # second one: this guarantees identical weights for the comparison.
     laf_estimator_jit = torch.jit.script(laf_estimator)
     assert_close(laf_estimator(laf, patches),
                  laf_estimator_jit(laf, patches))
 def test_toy(self, device):
     """A thin horizontal bar must yield the pinned, anisotropic output LAF."""
     estimator = LAFAffNetShapeEstimator(True).to(device).eval()
     img = torch.zeros(1, 1, 32, 32, device=device)
     # Rows 15..16 and columns 9..22 set to one: a 2 x 14 bright bar.
     img[:, :, 15:-15, 9:-9] = 1
     # Input LAF: diagonal 20 in the 2x2 part, last column (16, 16).
     laf_in = torch.tensor([[[[20.0, 0.0, 16.0], [0.0, 20.0, 16.0]]]],
                           device=device)
     laf_out = estimator(laf_in, img)
     laf_expected = torch.tensor(
         [[[[40.8758, 0.0, 16.0], [-0.3824, 9.7857, 16.0]]]], device=device)
     assert_close(laf_out, laf_expected, atol=1e-4, rtol=1e-4)
 def test_gradcheck(self, device):
     """Compare analytical and numerical gradients w.r.t. the LAF and image.

     Both inputs go through ``utils.tensor_to_gradcheck_var`` (presumably
     casting to double precision with ``requires_grad=True`` -- confirm in
     the testing utils). The estimator is switched to eval mode, as in every
     other test here, so the many forward passes made by finite differencing
     all see the inference-mode network rather than any training-mode
     behaviour of its layers.
     """
     batch_size, channels, height, width = 1, 1, 35, 35
     patches = torch.rand(batch_size,
                          channels,
                          height,
                          width,
                          device=device)
     patches = utils.tensor_to_gradcheck_var(patches)
     laf = torch.tensor([[[[8.0, 0.0, 16.0], [0.0, 8.0, 16.0]]]],
                        device=device)
     laf = utils.tensor_to_gradcheck_var(laf)
     estimator = LAFAffNetShapeEstimator(True).to(
         device, dtype=patches.dtype).eval()
     assert gradcheck(
         estimator,
         (laf, patches),
         raise_exception=True,
         rtol=1e-3,
         atol=1e-3,
         nondet_tol=1e-4,
     )
 def test_print(self, device):
     """The estimator must produce a non-empty textual representation.

     ``device`` is accepted only to match the fixture signature; building the
     estimator here does not move it to any device.
     """
     estimator = LAFAffNetShapeEstimator()
     assert repr(estimator)
 def test_shape_batch(self, device):
     """With 2 images and 5 LAFs per image, output LAFs keep the input shape."""
     batch, num_lafs = 2, 5
     img = torch.rand(batch, 1, 32, 32, device=device)
     lafs_in = torch.rand(batch, num_lafs, 2, 3, device=device)
     estimator = LAFAffNetShapeEstimator().to(device).eval()
     assert estimator(lafs_in, img).shape == lafs_in.shape
 def test_pretrained(self, device):
     """The estimator built with ``True`` maps one LAF to a same-shape LAF."""
     img = torch.rand(1, 1, 32, 32, device=device)
     lafs_in = torch.rand(1, 1, 2, 3, device=device)
     estimator = LAFAffNetShapeEstimator(True).to(device).eval()
     lafs_out = estimator(lafs_in, img)
     assert lafs_out.shape == lafs_in.shape