Beispiel #1
0
 def test_jit(self, device, dtype):
     """Check that a TorchScript-compiled SimpleKD matches eager execution.

     Builds one polar-kernel SimpleKD with 'lw' whitening on the requested
     device/dtype, scripts that same instance with ``torch.jit.script`` and
     asserts that eager and scripted outputs are close on random patches.
     """
     batch_size, channels, ps = 1, 1, 19
     # Honour the ``dtype`` fixture: previously the patches were always created
     # as float32 and ``dtype`` was never used, so every dtype parametrisation
     # exercised the same float32 path.
     patches = torch.rand(batch_size, channels, ps, ps,
                          device=device, dtype=dtype)
     model = SimpleKD(patch_size=ps, kernel_type='polar',
                      whitening='lw').to(patches.device,
                                         patches.dtype).eval()
     # Script the very same instance so both runs share identical parameters;
     # comparing two separately constructed models is only valid if SimpleKD's
     # initialisation happens to be deterministic.
     model_jit = torch.jit.script(model)
     assert_close(model(patches), model_jit(patches))
Beispiel #2
0
 def skd_describe(patches, patch_size=19):
     """Describe ``patches`` with a float64 polar SimpleKD using 'lw' whitening.

     Args:
         patches: input patch tensor; it is cast to float64 before the
             forward pass.
         patch_size: patch size forwarded to ``SimpleKD``.

     Returns:
         The output of the float64 SimpleKD applied to ``patches``.
     """
     # Use the ``patch_size`` argument; the body used to ignore it and read a
     # free ``ps`` name from an enclosing scope instead.
     # NOTE(review): float64 is presumably chosen because this helper feeds
     # torch.autograd.gradcheck, which expects double precision — confirm.
     skd = SimpleKD(patch_size=patch_size, kernel_type='polar',
                    whitening='lw').double()
     # Put the model on the same device as the input (the forward pass needs
     # them to match) instead of relying on an outer ``device`` name.
     skd.to(patches.device)
     return skd(patches.double())
Beispiel #3
0
 def test_batch_shape(self, bs, device):
     """A batch of ``bs`` 19x19 single-channel patches gives ``bs`` 128-d outputs."""
     patch_size = 19
     descriptor = SimpleKD(patch_size=patch_size,
                           kernel_type='polar').to(device)
     batch = torch.ones(bs, 1, patch_size, patch_size, device=device)
     assert descriptor(batch).shape == (bs, 128)
Beispiel #4
0
 def test_print(self, device):
     """Check that a polar SimpleKD produces a usable string representation."""
     skd = SimpleKD(patch_size=19, kernel_type='polar').to(device)
     # The ``repr`` builtin (unlike calling ``__repr__`` directly) also raises
     # TypeError if the result is not a ``str``; asserting it is non-empty
     # makes the test check the value instead of discarding it.
     assert repr(skd)
Beispiel #5
0
 def test_shape(self, ps, kernel_type, device):
     """One ps x ps patch maps to shape (1, min(128, self.dims[kernel_type]))."""
     descriptor = SimpleKD(patch_size=ps, kernel_type=kernel_type).to(device)
     single_patch = torch.ones(1, 1, ps, ps, device=device)
     descriptors = descriptor(single_patch)
     expected_dim = min(128, self.dims[kernel_type])
     assert descriptors.shape == (1, expected_dim)