def test_jit(self, device, dtype):
    """Eager and TorchScript-compiled MKDDescriptor must produce identical output."""
    bs, ch, ps = 1, 1, 19
    patches = torch.rand(bs, ch, ps, ps).to(device)
    kernel_type = 'concat'
    whitening = 'lw'
    # Build two independent descriptors with the same configuration:
    # one run eagerly, one compiled with torch.jit.script.
    eager = MKDDescriptor(
        patch_size=ps, kernel_type=kernel_type, whitening=whitening
    ).to(patches.device, patches.dtype).eval()
    scripted = torch.jit.script(
        MKDDescriptor(
            patch_size=ps, kernel_type=kernel_type, whitening=whitening
        ).to(patches.device, patches.dtype).eval()
    )
    assert_close(eager(patches), scripted(patches))
def test_batch_shape(self, bs, device):
    """Batched input of shape (bs, 1, 19, 19) yields descriptors of shape (bs, 238)."""
    descriptor = MKDDescriptor(patch_size=19, kernel_type='concat', whitening=None).to(device)
    patches = torch.ones(bs, 1, 19, 19).to(device)
    descs = descriptor(patches)
    assert descs.shape == (bs, 238)
def test_shape(self, ps, kernel_type, device):
    """Unwhitened output dimensionality must match the per-kernel table in self.dims."""
    descriptor = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=None).to(device)
    patches = torch.ones(1, 1, ps, ps).to(device)
    descs = descriptor(patches)
    assert descs.shape == (1, self.dims[kernel_type])
def test_whitened_shape(self, ps, kernel_type, whitening, device):
    """Whitening caps the output dimensionality at 128 (or the raw size, if smaller)."""
    descriptor = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=whitening).to(device)
    patches = torch.ones(1, 1, ps, ps).to(device)
    descs = descriptor(patches)
    expected_dims = min(self.dims[kernel_type], 128)
    assert descs.shape == (1, expected_dims)
def test_toy(self, device):
    """An all-zero 6x6 patch must produce (near-)zero values in the last 28 output dims."""
    patches = torch.ones(1, 1, 6, 6).to(device).float()
    # Zero out the single patch: the descriptor tail should then vanish.
    patches[0, 0, :, :] = 0
    descriptor = MKDDescriptor(patch_size=6, kernel_type='concat', whitening=None).to(device)
    descs = descriptor(patches)
    tail = descs[0, -28:]
    expected = torch.zeros_like(tail).to(device)
    assert_close(tail, expected, atol=1e-3, rtol=1e-3)
def mkd_describe(patches, patch_size=19):
    """Describe `patches` in double precision with a 'concat'-kernel MKD descriptor.

    NOTE(review): `whitening` and `device` are free variables resolved from the
    enclosing scope (this looks like a nested helper inside a test) — confirm
    against the surrounding function.
    """
    descriptor = MKDDescriptor(
        patch_size=patch_size, kernel_type='concat', whitening=whitening
    ).double()
    descriptor.to(device)
    return descriptor(patches.double())
def test_print(self, device):
    """Smoke-test that the descriptor's repr can be produced without error."""
    mkd = MKDDescriptor(patch_size=32, whitening='lw', training_set='liberty', output_dims=128).to(device)
    # Use the repr() builtin instead of calling __repr__ directly (idiomatic),
    # and assert the result is truthy so an empty/None repr fails the test
    # instead of silently passing.
    assert repr(mkd)