def test_batch_shape(self, bs, device):
    """A batch of `bs` 19x19 patches must map to a (bs, 238) descriptor tensor."""
    descriptor = MKDDescriptor(patch_size=19, kernel_type='concat', whitening=None).to(device)
    patches = torch.ones(bs, 1, 19, 19).to(device)
    descs = descriptor(patches)
    assert descs.shape == (bs, 238)
def test_shape(self, ps, kernel_type, device):
    """Output width must match the expected dimensionality for each kernel type.

    Expected widths are looked up in ``self.dims`` (defined on the test class).
    """
    descriptor = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=None).to(device)
    patches = torch.ones(1, 1, ps, ps).to(device)
    descs = descriptor(patches)
    assert descs.shape == (1, self.dims[kernel_type])
def test_whitened_shape(self, ps, kernel_type, whitening, device):
    """With whitening enabled the descriptor is capped at 128 dimensions."""
    descriptor = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=whitening).to(device)
    patches = torch.ones(1, 1, ps, ps).to(device)
    descs = descriptor(patches)
    # Whitening projects down to at most 128 dims; smaller raw outputs keep their size.
    expected_dims = min(self.dims[kernel_type], 128)
    assert descs.shape == (1, expected_dims)
def test_toy(self, device):
    """An all-zero 6x6 patch should produce (near-)zero values in the last 28 dims."""
    patches = torch.ones(1, 1, 6, 6).to(device).float()
    # Zero the whole patch: a constant-zero input is the trivial fixture here.
    patches[0, 0, :, :] = 0
    descriptor = MKDDescriptor(patch_size=6, kernel_type='concat', whitening=None).to(device)
    descs = descriptor(patches)
    tail = descs[0, -28:]
    expected = torch.zeros_like(tail).to(device)
    assert_allclose(tail, expected, atol=1e-3, rtol=1e-3)
def mkd_describe(patches, patch_size=19):
    """Describe ``patches`` with a double-precision concat-kernel MKD descriptor.

    NOTE(review): ``whitening`` and ``device`` are free variables resolved from
    the enclosing scope — confirm they are defined where this helper is used.
    """
    descriptor = MKDDescriptor(patch_size=patch_size, kernel_type='concat', whitening=whitening).double()
    descriptor.to(device)
    return descriptor(patches.double())