def test_jit(self, device, dtype):
    # 32x32 patch with a bright bar, consistent with the toy tests below;
    # the LAF is centered at (16, 16).
    B, C, H, W = 1, 1, 32, 32
    inp = torch.zeros(B, C, H, W, device=device, dtype=dtype)
    inp[:, :, 15:-15, 9:-9] = 1
    laf = torch.tensor([[[[20.0, 0.0, 16.0], [0.0, 20.0, 16.0]]]], device=device, dtype=dtype)
    tfeat = LAFAffineShapeEstimator(W).to(inp.device, inp.dtype).eval()
    tfeat_jit = torch.jit.script(
        LAFAffineShapeEstimator(W).to(inp.device, inp.dtype).eval())
    # The scripted module must produce the same LAFs as the eager one.
    assert_close(tfeat_jit(laf, inp), tfeat(laf, inp))
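
# A bright horizontal bar on a dark background: affine shape adaptation
# should stretch the LAF along x and shrink it along y.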
def test_toy(self, device):
    aff = LAFAffineShapeEstimator(32).to(device)
    inp = torch.zeros(1, 1, 32, 32, device=device)
    inp[:, :, 15:-15, 9:-9] = 1
    laf = torch.tensor([[[[20.0, 0.0, 16.0], [0.0, 20.0, 16.0]]]], device=device)
    new_laf = aff(laf, inp)
    expected = torch.tensor([[[[36.643, 0.0, 16.0], [0.0, 10.916, 16.0]]]], device=device)
    assert_close(new_laf, expected, atol=1e-4, rtol=1e-4)
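
# Same bar, but the input LAF carries a 90-degree rotation;
# preserve_orientation=True must keep that rotation while adapting the shape.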
def test_toy_preserve(self, device, dtype):
    aff = LAFAffineShapeEstimator(32, preserve_orientation=True).to(device, dtype)
    inp = torch.zeros(1, 1, 32, 32, device=device, dtype=dtype)
    inp[:, :, 15:-15, 9:-9] = 1
    laf = torch.tensor([[[[0.0, 20.0, 16.0], [-20.0, 0.0, 16.0]]]], device=device, dtype=dtype)
    new_laf = aff(laf, inp)
    expected = torch.tensor([[[[0.0, 36.643, 16.0], [-10.916, 0.0, 16.0]]]], device=device, dtype=dtype)
    assert_close(new_laf, expected, atol=1e-4, rtol=1e-4)
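
# Gradients must flow through both the LAFs and the image patches.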
def test_gradcheck(self, device):
    batch_size, channels, height, width = 1, 1, 40, 40
    patches = torch.rand(batch_size, channels, height, width, device=device)
    patches = utils.tensor_to_gradcheck_var(patches)  # to var
    laf = torch.tensor([[[[5.0, 0.0, 26.0], [0.0, 5.0, 26.0]]]], device=device)
    laf = utils.tensor_to_gradcheck_var(laf)  # to var
    assert gradcheck(
        LAFAffineShapeEstimator(11).to(device),
        (laf, patches),
        raise_exception=True,
        rtol=1e-3,
        atol=1e-3,
        nondet_tol=1e-4,
    )

def test_print(self, device):
    sift = LAFAffineShapeEstimator()
    sift.__repr__()
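
# The output must keep the input (B, N, 2, 3) LAF shape (here N=34).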
def test_shape_batch(self, device):
    inp = torch.rand(2, 1, 32, 32, device=device)
    laf = torch.rand(2, 34, 2, 3, device=device)
    ori = LAFAffineShapeEstimator().to(device)
    out = ori(laf, inp)
    assert out.shape == laf.shape