Esempio n. 1
0
 def test_jit(self, device, dtype):
     """Check that the TorchScript-compiled encoder matches eager execution.

     A single ``ExplicitSpacialEncoding`` instance is built and then passed to
     ``torch.jit.script``. Both code paths therefore use identical constructor
     arguments, and any output mismatch can only come from scripting.
     """
     B, C, H, W = 2, 2, 13, 13
     patches = torch.rand(B, C, H, W, device=device, dtype=dtype)
     # in_dims must match the channel count of the input patches.
     model = ExplicitSpacialEncoding(
         kernel_type='cart', fmap_size=W, in_dims=C
     ).to(patches.device, patches.dtype).eval()
     model_jit = torch.jit.script(model)
     assert_close(model(patches), model_jit(patches))
Esempio n. 2
0
 def test_batch_shape(self, kernel_type, bs, device):
     """Check the output shape of the encoder for a given batch size.

     The expected shape is ``(bs, per_channel * channels)``, where
     ``per_channel`` is 9 for the 'cart' kernel and 25 otherwise.
     """
     channels, size = 7, 15
     ese = ExplicitSpacialEncoding(kernel_type=kernel_type,
                                   fmap_size=size,
                                   in_dims=channels).to(device)
     out = ese(torch.ones(bs, channels, size, size, device=device))
     per_channel = 25
     if kernel_type == 'cart':
         per_channel = 9
     assert out.shape == (bs, per_channel * channels)
Esempio n. 3
0
 def test_shape(self, kernel_type, ps, in_dims, device):
     """Check the output shape for a single patch of side ``ps``.

     The expected shape is ``(1, per_channel * in_dims)``, where
     ``per_channel`` is 9 for the 'cart' kernel and 25 otherwise.
     """
     encoder = ExplicitSpacialEncoding(kernel_type=kernel_type,
                                       fmap_size=ps,
                                       in_dims=in_dims).to(device)
     result = encoder(torch.ones(1, in_dims, ps, ps, device=device))
     per_channel = 25
     if kernel_type == 'cart':
         per_channel = 9
     assert result.shape == (1, per_channel * in_dims)
Esempio n. 4
0
    def test_toy(self, device):
        """A zeroed input channel must give near-zero leading descriptor values.

        Channel 0 of a 2-channel 6x6 input is set to zero and channel 1 is left
        at one. For each kernel type, the leading slice of the output is
        compared against zeros: 9 values for 'cart' and 25 for 'polar'.
        Presumably this slice is the descriptor of channel 0 (TODO confirm).
        """
        inp = torch.ones(1, 2, 6, 6, device=device, dtype=torch.float32)
        inp[0, 0, :, :] = 0
        for kernel_type, dim in (('cart', 9), ('polar', 25)):
            ese = ExplicitSpacialEncoding(kernel_type=kernel_type,
                                          fmap_size=6,
                                          in_dims=2).to(device)
            out_part = ese(inp)[:, :dim]
            # zeros_like already inherits out_part's device and dtype, so no
            # extra .to(device) is needed.
            expected = torch.zeros_like(out_part)
            assert_close(out_part, expected, atol=1e-3, rtol=1e-3)
Esempio n. 5
0
 def explicit_spatial_describe(patches, ps=13):
     """Encode ``patches`` with a freshly built 2-channel ExplicitSpacialEncoding.

     NOTE(review): ``kernel_type`` and ``device`` are free variables resolved
     from an enclosing scope that is not visible here. Confirm them at the
     definition site.
     """
     encoder = ExplicitSpacialEncoding(kernel_type=kernel_type,
                                       fmap_size=ps,
                                       in_dims=2)
     # Module.to moves the module in place and returns it, so it can be
     # called directly.
     return encoder.to(device)(patches)
Esempio n. 6
0
 def test_print(self, kernel_type, device):
     """Smoke test: building the encoder and taking its repr() must not raise."""
     encoder = ExplicitSpacialEncoding(
         kernel_type=kernel_type, fmap_size=15, in_dims=7
     ).to(device)
     repr(encoder)