コード例 #1
0
 def test_jit(self, device, dtype):
     """Check that a TorchScript-compiled MKDDescriptor matches eager mode.

     The very same descriptor instance is scripted and compared against its
     eager execution, so the comparison does not rely on two independently
     constructed instances ending up with identical state.

     Args:
         device: torch device the test runs on (fixture).
         dtype: floating-point dtype used for both the input patches and the
             model (fixture).
     """
     batch_size, channels, ps = 1, 1, 19
     # Honor the dtype fixture: previously the patches were always created in
     # the default dtype, so the parameter had no effect on the test.
     patches = torch.rand(batch_size, channels, ps, ps,
                          device=device, dtype=dtype)
     model = MKDDescriptor(patch_size=ps, kernel_type='concat',
                           whitening='lw').to(device, dtype).eval()
     # torch.jit.script returns a separate ScriptModule; the eager module
     # remains usable for the reference computation.
     model_jit = torch.jit.script(model)
     assert_close(model(patches), model_jit(patches))
コード例 #2
0
 def test_batch_shape(self, bs, device):
     """A batch of ``bs`` single-channel 19x19 patches must produce ``bs``
     rows of 238-dimensional unwhitened 'concat' descriptors."""
     descriptor = MKDDescriptor(
         patch_size=19, kernel_type='concat', whitening=None,
     ).to(device)
     batch = torch.ones(bs, 1, 19, 19, device=device)
     assert descriptor(batch).shape == (bs, 238)
コード例 #3
0
 def test_shape(self, ps, kernel_type, device):
     """An unwhitened descriptor of one ``ps`` x ``ps`` patch has the length
     recorded in ``self.dims`` for the requested kernel type."""
     descriptor = MKDDescriptor(patch_size=ps, kernel_type=kernel_type,
                                whitening=None).to(device)
     patch = torch.ones(1, 1, ps, ps, device=device)
     result = descriptor(patch)
     assert result.shape == (1, self.dims[kernel_type])
コード例 #4
0
 def test_whitened_shape(self, ps, kernel_type, whitening, device):
     """For the given ``whitening`` mode, the descriptor length equals the
     kernel type's dimensionality from ``self.dims`` capped at 128."""
     descriptor = MKDDescriptor(patch_size=ps, kernel_type=kernel_type,
                                whitening=whitening).to(device)
     patch = torch.ones(1, 1, ps, ps, device=device)
     result = descriptor(patch)
     # NOTE(review): presumably whitening projects to at most 128 dims by
     # default (output_dims is not passed here) -- confirm in MKDDescriptor.
     expected_dims = min(self.dims[kernel_type], 128)
     assert result.shape == (1, expected_dims)
コード例 #5
0
 def test_toy(self, device):
     """On an all-zero 6x6 float32 patch, the last 28 entries of the
     unwhitened 'concat' descriptor must be (approximately) zero."""
     blank = torch.zeros(1, 1, 6, 6, device=device, dtype=torch.float32)
     descriptor = MKDDescriptor(patch_size=6, kernel_type='concat',
                                whitening=None).to(device)
     tail = descriptor(blank)[0, -28:]
     assert_close(tail, torch.zeros_like(tail), atol=1e-3, rtol=1e-3)
コード例 #6
0
 def mkd_describe(patches, patch_size=19):
     """Describe ``patches`` with a float64 'concat' MKDDescriptor.

     ``whitening`` and ``device`` are free names resolved from the enclosing
     scope (presumably the surrounding test's fixtures -- confirm there).

     Args:
         patches: input tensor; it is converted to float64 before use.
         patch_size: patch size passed to MKDDescriptor.

     Returns:
         The descriptor output for the float64 patches.
     """
     descriptor = MKDDescriptor(patch_size=patch_size, kernel_type='concat',
                                whitening=whitening)
     # nn.Module.double()/.to() convert in place and return the module itself.
     descriptor = descriptor.double().to(device)
     return descriptor(patches.double())
コード例 #7
0
 def test_print(self, device):
     """Smoke-test the string representation of a whitened MKDDescriptor."""
     mkd = MKDDescriptor(patch_size=32,
                         whitening='lw',
                         training_set='liberty',
                         output_dims=128).to(device)
     # Use the repr() builtin instead of calling the dunder directly; repr()
     # also raises TypeError if __repr__ does not return a str.
     text = repr(mkd)
     assert text, "MKDDescriptor repr should not be empty"