def testDecompositionIsNonRedundant(self):
    """Checks that the wavelet decomposition holds no redundant coefficients.

    For every filter yielded by `wavelet.generate_filters()`, four random
    trials are run. Each trial:
      1) constructs a wavelet decomposition of a random image,
      2) overwrites one randomly chosen coefficient with a random value,
      3) collapses the altered decomposition into an image and constructs
         the decomposition of that image again.
    The altered decomposition and the re-constructed one must agree to
    within 1e-8.
    """
    num_trials = 4
    for wavelet_type in wavelet.generate_filters():
        for _ in range(num_trials):
            # Random number of levels; the two trailing image dimensions are
            # clamped from below by 2**(num_levels - 1) + 1.
            num_levels = np.int32(np.ceil(4 * np.random.uniform()))
            min_side = 2**(num_levels - 1) + 1
            random_size = np.int32(
                np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3)))
            size_floor = np.array([0, min_side, min_side])
            image = np.random.uniform(size=np.maximum(random_size, size_floor))
            pyr = wavelet.construct(image, num_levels, wavelet_type)

            # Draw which entry of the decomposition to touch and the value to
            # write into it.
            level = np.int32(np.floor(np.random.uniform() * len(pyr)))
            new_value = np.random.uniform()

            # The last entry is indexed directly as [c, i, j]; every other
            # entry is a collection that is first indexed by a random band.
            if level != len(pyr) - 1:
                band = np.int32(
                    np.floor(np.random.uniform() * len(pyr[level])))
                target = pyr[level][band]
            else:
                target = pyr[level]

            # Empty entries have no coefficient to alter and are left alone.
            if np.prod(target.shape) > 0:
                fractions = np.array(np.random.uniform(size=3))
                c, i, j = np.int32(
                    np.floor(fractions * target.shape)).tolist()
                target[c, i, j] = new_value

            # Collapse, re-construct, and require the altered decomposition
            # to come back unchanged.
            recon = wavelet.collapse(pyr, wavelet_type)
            pyr_again = wavelet.construct(recon, num_levels, wavelet_type)
            self._assert_pyramids_close(pyr, pyr_again, 1e-8)
def transform_to_mat(self, x):
    """Turns a batch of images into a matrix with one row per image.

    Args:
      x: A rank-4 batch of images laid out as
        (num_batches, width, height, num_channels).

    Returns:
      A tensor reshaped to (num_batches, width * height * num_channels),
      after the color-space and spatial-representation transforms selected
      by `self.color_space` and `self.representation` have been applied.
    """
    assert len(x.shape) == 4
    x = torch.as_tensor(x)

    # Only 'YUV' triggers a color conversion; any other setting (e.g. 'RGB')
    # leaves the input as it is.
    if self.color_space == 'YUV':
        x = util.rgb_to_syuv(x)

    # Move channels next to the batch axis and fold the two together:
    # (num_batches, width, height, num_channels)
    #   -> (num_batches * num_channels, width, height)
    _, width, height, num_channels = x.shape
    channels_first = x.permute(0, 3, 1, 2)
    x_stack = torch.reshape(channels_first, (-1, width, height))

    # Apply the spatial representation to every stacked channel. Anything
    # that is neither a wavelet filter nor 'DCT' (e.g. 'PIXEL') is a no-op.
    if self.representation in wavelet.generate_filters():
        pyramid = wavelet.construct(
            x_stack, self.wavelet_num_levels, self.representation)
        pyramid = wavelet.rescale(pyramid, self.wavelet_scale_base)
        x_stack = wavelet.flatten(pyramid)
    elif self.representation == 'DCT':
        x_stack = util.image_dct(x_stack)

    # Undo the stacking and flatten each image into a single row:
    # (num_batches * num_channels, width, height)
    #   -> (num_batches, width * height * num_channels)
    unstacked = torch.reshape(x_stack, (-1, num_channels, width, height))
    channels_last = unstacked.permute(0, 2, 3, 1)
    x_mat = torch.reshape(channels_last, [-1, width * height * num_channels])
    return x_mat
def testRescaleAndUnrescaleReproducesInput(self):
    """Tests that rescaling by k and then by 1/k gives back the input."""
    image = np.random.uniform(size=(2, 32, 32))
    # exp() of a normal draw yields a random positive scale base.
    k = np.exp(np.random.normal())
    original = wavelet.construct(image, 4, 'LeGall5/3')
    round_tripped = wavelet.rescale(wavelet.rescale(original, k), 1. / k)
    self._assert_pyramids_close(original, round_tripped, 1e-8)
def testRescaleOneHalfIsNormalized(self):
    """Tests that rescale(construct(k), 0.5)[-1] = k for a constant image k.

    Runs once for each number of levels from 0 through 4, with a fresh
    random constant each time.
    """
    tolerance = dict(atol=1e-8, rtol=1e-8)
    for num_levels in range(5):
        k = np.random.uniform()
        constant_image = k * np.ones((2, 32, 32))
        rescaled = wavelet.rescale(
            wavelet.construct(constant_image, num_levels, 'LeGall5/3'), 0.5)
        # The last entry of the rescaled decomposition must equal k
        # everywhere.
        last = rescaled[-1]
        np.testing.assert_allclose(last, k * np.ones_like(last), **tolerance)
def testAccurateRoundTripWithSmallRandomImages(self):
    """Tests collapse(construct(x)) == x for x of size [1, k, k], k = 0..4.

    Each size is decomposed with the number of levels reported by
    `wavelet.get_max_num_levels` for that size, for every filter yielded by
    `wavelet.generate_filters()`.
    """
    tolerance = dict(atol=1e-8, rtol=1e-8)
    for wavelet_type in wavelet.generate_filters():
        for width in range(5):
            size = [1, width, width]
            depth = wavelet.get_max_num_levels(size)
            image = np.random.uniform(size=size)
            pyramid = wavelet.construct(image, depth, wavelet_type)
            restored = wavelet.collapse(pyramid, wavelet_type)
            np.testing.assert_allclose(restored, image, **tolerance)
def testAccurateRoundTripWithLargeRandomImages(self):
    """Tests that collapse(construct(x)) == x for large random images x.

    For every filter yielded by `wavelet.generate_filters()`, four images
    with a random size and a random number of levels are decomposed and
    collapsed, and the result is compared against the input to within 1e-8.
    """
    num_trials = 4
    tolerance = dict(atol=1e-8, rtol=1e-8)
    for wavelet_type in wavelet.generate_filters():
        for _ in range(num_trials):
            # Random number of levels; the two trailing image dimensions are
            # clamped from below by 2**(num_levels - 1) + 1.
            num_levels = np.int32(np.ceil(4 * np.random.uniform()))
            min_side = 2**(num_levels - 1) + 1
            random_size = np.int32(
                np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3)))
            size_floor = np.array([0, min_side, min_side])
            image = np.random.uniform(size=np.maximum(random_size, size_floor))

            pyramid = wavelet.construct(image, num_levels, wavelet_type)
            restored = wavelet.collapse(pyramid, wavelet_type)
            np.testing.assert_allclose(restored, image, **tolerance)
def testConstructMatchesGoldenData(self, device):
    """Tests construct() against golden data, running on `device`.

    Args:
      device: The torch device on which the input image tensor is created.
    """
    im, pyr_true, wavelet_type = self._load_golden_data()
    im = torch.tensor(im, device=device)
    pyr = list(wavelet.construct(im, len(pyr_true) - 1, wavelet_type))

    # Bring every tensor back to the CPU, detached from the graph, before
    # comparing. All entries but the last hold three bands each; the last
    # entry is a single tensor.
    last = len(pyr) - 1
    for level in range(last):
        bands = list(pyr[level])
        for b in range(3):
            bands[b] = bands[b].cpu().detach()
        pyr[level] = bands
    pyr[last] = pyr[last].cpu().detach()

    self._assert_pyramids_close(pyr, pyr_true, 1e-5)
def testWaveletTransformationIsVolumePreserving(self):
    """Tests that construct() is volume preserving when size is a power of 2.

    For every filter yielded by `wavelet.generate_filters()`, the Jacobian of
    the flattened two-level decomposition of a random (1, 4, 4) image is
    built with autograd, and its determinant is asserted to be close to 1.
    """
    sz = (1, 4, 4)
    num_levels = 2
    for wavelet_type in wavelet.generate_filters():
        im = np.float32(np.random.uniform(0., 1., sz))

        # Build the Jacobian one output coefficient at a time: backpropagate
        # the d-th flattened coefficient and record its gradient with respect
        # to the (flattened) image.
        jacobian = []
        for d in range(im.size):
            # A fresh leaf tensor per coefficient, so `.grad` holds only this
            # coefficient's gradient. `torch.tensor(..., requires_grad=True)`
            # replaces the deprecated `torch.autograd.Variable` wrapper.
            var_im = torch.tensor(im, requires_grad=True)
            coeffs = wavelet.flatten(
                wavelet.construct(var_im, num_levels, wavelet_type))
            torch.reshape(coeffs, [-1])[d].backward()
            jacobian.append(np.reshape(var_im.grad.detach().numpy(), [-1]))
        # Each recorded gradient becomes one column of the Jacobian.
        jacobian = np.stack(jacobian, 1)

        # Assert that the determinant of the Jacobian is close to 1.
        det = np.linalg.det(jacobian)
        np.testing.assert_allclose(det, 1., atol=1e-5, rtol=1e-5)
def testRescaleDoesNotAffectTheFirstLevel(self):
    """Tests that rescale(x, s)[0] = x[0] for a randomly drawn s."""
    image = np.random.uniform(size=(2, 32, 32))
    pyramid = wavelet.construct(image, 4, 'LeGall5/3')
    # exp() of a normal draw yields a random positive scale base.
    scale_base = np.exp(np.random.normal())
    rescaled = wavelet.rescale(pyramid, scale_base)
    # Compare only the first entry of each decomposition.
    self._assert_pyramids_close(pyramid[:1], rescaled[:1], 1e-8)
def testRescaleOneIsANoOp(self):
    """Tests that rescaling with a scale base of 1 leaves x unchanged."""
    image = np.random.uniform(size=(2, 32, 32))
    pyramid = wavelet.construct(image, 4, 'LeGall5/3')
    self._assert_pyramids_close(
        pyramid, wavelet.rescale(pyramid, 1.), 1e-8)
def _construct_preserves_dtype(self, float_dtype):
    """Checks that construct()'s output has the same precision as its input.

    Args:
      float_dtype: A callable dtype (it is called on a numpy array to cast
        it) that the flattened output's numpy dtype is compared against.
    """
    # One random input is shared across all wavelet filters.
    x = float_dtype(np.random.normal(size=(3, 16, 16)))
    for wavelet_type in wavelet.generate_filters():
        pyramid = wavelet.construct(x, 3, wavelet_type)
        flat = wavelet.flatten(pyramid)
        output_dtype = flat.detach().numpy().dtype
        np.testing.assert_equal(output_dtype, float_dtype)
def testConstructMatchesGoldenData(self):
    """Tests construct() against golden data."""
    image, expected_pyr, wavelet_type = self._load_golden_data()
    # The number of levels passed to construct() is one less than the number
    # of entries in the golden decomposition.
    num_levels = len(expected_pyr) - 1
    actual_pyr = wavelet.construct(image, num_levels, wavelet_type)
    self._assert_pyramids_close(actual_pyr, expected_pyr, 1e-5)