def test_jit(self, device, dtype):
    """Check that a TorchScript-compiled ``SimpleKD`` matches eager execution.

    The input patches are created directly on ``device`` with ``dtype``, so the
    ``dtype`` fixture actually controls the precision under test.
    """
    batch_size, channels, ps = 1, 1, 19
    patches = torch.rand(batch_size, channels, ps, ps, device=device, dtype=dtype)
    model = SimpleKD(patch_size=ps, kernel_type='polar', whitening='lw').to(device, dtype).eval()
    # Script the very same instance so the eager and scripted paths share
    # identical parameters/buffers and the model is only constructed once.
    model_jit = torch.jit.script(model)
    assert_close(model(patches), model_jit(patches))
def skd_describe(patches, patch_size=19):
    """Describe ``patches`` with a double-precision ``SimpleKD`` (gradcheck helper).

    Args:
        patches: input patch tensor; it is converted to double before the forward pass.
        patch_size: patch size used to build the descriptor (previously ignored in
            favour of the enclosing ``ps``).

    Returns:
        The descriptor output for ``patches.double()``.
    """
    skd = SimpleKD(patch_size=patch_size, kernel_type='polar', whitening='lw').double()
    # NOTE(review): ``device`` is captured from the enclosing test's scope — confirm.
    skd.to(device)
    return skd(patches.double())
def test_batch_shape(self, bs, device):
    """A batch of ``bs`` single-channel 19x19 patches yields a ``(bs, 128)`` output."""
    descriptor = SimpleKD(patch_size=19, kernel_type='polar').to(device)
    batch = torch.ones(bs, 1, 19, 19, device=device)
    result = descriptor(batch)
    expected_shape = (bs, 128)
    assert result.shape == expected_shape
def test_print(self, device):
    """Smoke-test the string representation of a ``SimpleKD`` descriptor."""
    skd = SimpleKD(patch_size=19, kernel_type='polar').to(device)
    text = repr(skd)
    # Actually check the result instead of discarding it: it must be a usable string.
    assert isinstance(text, str)
    assert text
def test_shape(self, ps, kernel_type, device):
    """A single ``ps x ps`` patch yields a ``(1, D)`` output.

    ``D`` is ``self.dims[kernel_type]``, capped at 128.
    """
    descriptor = SimpleKD(patch_size=ps, kernel_type=kernel_type).to(device)
    patch = torch.ones(1, 1, ps, ps, device=device)
    result = descriptor(patch)
    expected_dim = min(128, self.dims[kernel_type])
    assert result.shape == (1, expected_dim)