def testDecompositionIsNonRedundant(self):
    """Checks that the wavelet decomposition carries no redundancy.

    A non-redundant decomposition must survive the following round trip:
      1) build a wavelet decomposition of a random image,
      2) overwrite one randomly chosen coefficient of it,
      3) collapse the edited decomposition into an image and decompose that
         image again.
    The re-decomposed pyramid must match the edited one.
    """
    for wavelet_type in wavelet.generate_filters():
        for _ in range(4):
            # Random image; spatial sides are clamped so that num_levels fits.
            num_levels = np.int32(np.ceil(4 * np.random.uniform()))
            min_side = 2**(num_levels - 1) + 1
            random_sz = np.int32(
                np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3)))
            sz = np.maximum(random_sz, np.array([0, min_side, min_side]))
            im = np.random.uniform(size=sz)
            pyr = list(wavelet.construct(im, num_levels, wavelet_type))

            # Choose the pyramid entry to edit and the value to write into it.
            # The last entry is a single array; every other entry is a
            # sequence of bands, so one band must additionally be chosen.
            level = np.int32(np.floor(np.random.uniform() * len(pyr)))
            new_value = np.random.uniform()
            is_last_entry = level == (len(pyr) - 1)
            if is_last_entry:
                target_shape = pyr[level].shape
            else:
                band = np.int32(
                    np.floor(np.random.uniform() * len(pyr[level])))
                target_shape = pyr[level][band].shape

            # Empty arrays have no coefficient to alter; leave them as-is.
            if np.prod(target_shape) > 0:
                c, i, j = np.int32(
                    np.floor(np.array(np.random.uniform(size=3)) *
                             target_shape)).tolist()
                # Convert to a mutable numpy array before writing in place.
                if is_last_entry:
                    pyr[level] = pyr[level].numpy()
                    pyr[level][c, i, j] = new_value
                else:
                    pyr[level] = list(pyr[level])
                    pyr[level][band] = pyr[level][band].numpy()
                    pyr[level][band][c, i, j] = new_value

            # Image-space round trip: the edited pyramid must come back
            # unchanged.
            recon = wavelet.collapse(pyr, wavelet_type)
            pyr_again = wavelet.construct(recon, num_levels, wavelet_type)
            self._assert_pyramids_close(pyr, pyr_again, 1e-8)
def testRescaleAndUnrescaleReproducesInput(self):
    """Rescaling by k and then by 1/k must return the original pyramid."""
    im = np.random.uniform(size=(2, 32, 32))
    # Random positive scale factor.
    k = np.exp(np.random.normal())
    pyr = wavelet.construct(im, 4, 'LeGall5/3')
    round_trip = wavelet.rescale(wavelet.rescale(pyr, k), 1. / k)
    self._assert_pyramids_close(pyr, round_trip, 1e-8)
def testAccurateRoundTripWithSmallRandomImages(self):
    """Checks collapse(construct(x)) == x for x of shape [1, k, k], k in 1..4.

    Each image is decomposed with the maximum number of levels its size allows.
    """
    for wavelet_type in wavelet.generate_filters():
        for sz in ([1, k, k] for k in range(1, 5)):
            num_levels = wavelet.get_max_num_levels(sz)
            im = np.random.uniform(size=sz)
            recon = wavelet.collapse(
                wavelet.construct(im, num_levels, wavelet_type), wavelet_type)
            self.assertAllClose(recon, im, atol=1e-8, rtol=1e-8)
def testRescaleOneHalfIsNormalized(self):
    """For a constant image k, rescale(construct(k), 0.5)[-1] must equal k.

    This is checked for pyramids built with 0 through 4 levels.
    """
    for num_levels in range(5):
        k = np.random.uniform()
        pyr = wavelet.construct(
            k * np.ones((2, 32, 32)), num_levels, 'LeGall5/3')
        last_entry = wavelet.rescale(pyr, 0.5)[-1]
        self.assertAllClose(
            last_entry, k * np.ones_like(last_entry), atol=1e-8, rtol=1e-8)
def testAccurateRoundTripWithLargeRandomImages(self):
    """Checks collapse(construct(x)) == x for randomly sized images.

    Image sizes are drawn up to [2, 32, 32], and the number of levels is drawn
    from 1 to 4.
    """
    for wavelet_type in wavelet.generate_filters():
        for _ in range(4):
            num_levels = np.int32(np.ceil(4 * np.random.uniform()))
            # Spatial sides are clamped to at least 2**(num_levels - 1) + 1.
            min_side = 2**(num_levels - 1) + 1
            upper = np.array([2, 32, 32])
            sz = np.maximum(
                np.int32(np.ceil(upper * np.random.uniform(size=3))),
                np.array([0, min_side, min_side]))
            im = np.random.uniform(size=sz)
            pyr = wavelet.construct(im, num_levels, wavelet_type)
            self.assertAllClose(
                wavelet.collapse(pyr, wavelet_type), im, atol=1e-8, rtol=1e-8)
def testRescaleDoesNotAffectTheFirstLevel(self):
    """rescale(x, s)[0] must equal x[0] for an arbitrary positive scale s."""
    im = np.random.uniform(size=(2, 32, 32))
    pyr = wavelet.construct(im, 4, 'LeGall5/3')
    scale = np.exp(np.random.normal())
    pyr_rescaled = wavelet.rescale(pyr, scale)
    # Compare only the first entry of each pyramid.
    self._assert_pyramids_close(pyr[:1], pyr_rescaled[:1], 1e-8)
def testRescaleOneIsANoOp(self):
    """rescale(x, 1) must leave the pyramid x unchanged."""
    pyr = wavelet.construct(
        np.random.uniform(size=(2, 32, 32)), 4, 'LeGall5/3')
    self._assert_pyramids_close(pyr, wavelet.rescale(pyr, 1.), 1e-8)
def testConstructPreservesDtype(self, float_dtype):
    """construct() output must keep the floating-point precision of its input.

    Args:
      float_dtype: numpy float type used to cast the input image.
    """
    x = float_dtype(np.random.normal(size=(3, 16, 16)))
    # Lazily flatten one 3-level decomposition per available filter type.
    flattened = (
        wavelet.flatten(wavelet.construct(x, 3, filter_type))
        for filter_type in wavelet.generate_filters())
    for y in flattened:
        self.assertDTypeEqual(y, float_dtype)
def testConstructMatchesGoldenData(self):
    """construct() must reproduce the golden pyramid from _load_golden_data()."""
    im, pyr_true, wavelet_type = self._load_golden_data()
    # The golden pyramid holds one more entry than the number of levels.
    num_levels = len(pyr_true) - 1
    self._assert_pyramids_close(
        wavelet.construct(im, num_levels, wavelet_type), pyr_true, 1e-5)
def fun(z):  # pylint: disable=cell-var-from-loop
    """Returns the flattened wavelet decomposition of `z`.

    `num_levels` and `wavelet_type` are free variables read from the enclosing
    scope when the function is called (presumably a loop, given the pylint
    suppression; the enclosing code is not visible here).
    """
    decomposition = wavelet.construct(z, num_levels, wavelet_type)
    return wavelet.flatten(decomposition)