def testAccurateRoundTripWithSmallRandomImages(self):
    """Tests that collapse(construct(x)) == x for x = [1, k, k], k in [0, 4].

    The width k = 0 (an empty image) is covered as well. Each image is
    decomposed with the level count returned by wavelet.get_max_num_levels()
    for its shape, for every filter in wavelet.generate_filters().
    """
    for wavelet_type in wavelet.generate_filters():
        for width in range(5):
            sz = [1, width, width]
            num_levels = wavelet.get_max_num_levels(sz)
            im = np.random.uniform(size=sz)
            pyr = wavelet.construct(im, num_levels, wavelet_type)
            recon = wavelet.collapse(pyr, wavelet_type)
            np.testing.assert_allclose(recon, im, atol=1e-8, rtol=1e-8)
def _collapse_preserves_dtype(self, float_dtype):
    """Checks that collapse()'s output has the same precision as its input.

    Builds a three-level pyramid (detail bands of width 8, 4 and 2, three bands
    per level, each of shape (3, width, width)) plus a residual whose width
    matches the coarsest level. All arrays are cast with `float_dtype`.

    Args:
      float_dtype: A numpy float type (callable on an array) whose precision
        collapse() is expected to preserve.
    """
    band_widths = [8, 4, 2]
    pyr = [
        [float_dtype(np.random.normal(size=(3, n, n))) for _ in range(3)]
        for n in band_widths
    ]
    # The residual has the same width as the coarsest detail level. This was
    # previously implied by the loop variable leaking out of the for-loop.
    residual_width = band_widths[-1]
    pyr.append(
        float_dtype(np.random.normal(size=(3, residual_width, residual_width))))
    for wavelet_type in wavelet.generate_filters():
        y = wavelet.collapse(pyr, wavelet_type)
        np.testing.assert_equal(y.detach().numpy().dtype, float_dtype)
def testAccurateRoundTripWithLargeRandomImages(self):
    """Checks collapse(construct(x)) == x on randomly sized images.

    For every filter, four trials draw a level count in [1, 4] and an image
    shape of up to (2, 32, 32). The spatial dimensions are clamped to at least
    2**(num_levels - 1) + 1.
    """
    max_shape = np.array([2, 32, 32])
    for wavelet_type in wavelet.generate_filters():
        for _trial in range(4):
            num_levels = np.int32(np.ceil(4 * np.random.uniform()))
            min_width = 2**(num_levels - 1) + 1
            drawn_shape = np.int32(
                np.ceil(max_shape * np.random.uniform(size=3)))
            lower_bound = np.array([0, min_width, min_width])
            image = np.random.uniform(size=np.maximum(drawn_shape, lower_bound))
            decomposition = wavelet.construct(image, num_levels, wavelet_type)
            rebuilt = wavelet.collapse(decomposition, wavelet_type)
            np.testing.assert_allclose(rebuilt, image, atol=1e-8, rtol=1e-8)
def testCollapseMatchesGoldenData(self, device):
    """Checks collapse() against golden data with coefficients on `device`.

    Each golden coefficient array is wrapped with torch.tensor(...,
    device=device) before collapsing. The reconstruction is moved back to the
    CPU and detached before it is compared with the golden image.

    Args:
      device: Device argument forwarded to torch.tensor().
    """
    image, golden_pyr, wavelet_type = self._load_golden_data()

    def to_device(array):
        return torch.tensor(array, device=device)

    pyr = list(golden_pyr)
    residual_index = len(pyr) - 1
    # Every entry before the last is a level of three bands.
    for level in range(residual_index):
        bands = list(pyr[level])
        for band in range(3):
            bands[band] = to_device(bands[band])
        pyr[level] = bands
    # The last entry is a single array.
    pyr[residual_index] = to_device(pyr[residual_index])

    rebuilt = wavelet.collapse(pyr, wavelet_type).cpu().detach()
    np.testing.assert_allclose(rebuilt, image, atol=1e-5, rtol=1e-5)
def testDecompositionIsNonRedundant(self):
    """Tests that the wavelet decomposition is not redundant.

    If the decomposition has no redundant coefficients, then the following
    round trip leaves it unchanged:
      1) construct a wavelet decomposition of a random image,
      2) overwrite a single, randomly chosen coefficient,
      3) collapse the decomposition into an image and construct it again.
    The decomposition from step 3 must match the edited one from step 2.
    """
    max_shape = np.array([2, 32, 32])
    for wavelet_type in wavelet.generate_filters():
        for _trial in range(4):
            # Random image; spatial dims are clamped to at least
            # 2**(num_levels - 1) + 1.
            num_levels = np.int32(np.ceil(4 * np.random.uniform()))
            min_width = 2**(num_levels - 1) + 1
            drawn_shape = np.int32(
                np.ceil(max_shape * np.random.uniform(size=3)))
            shape = np.maximum(drawn_shape,
                               np.array([0, min_width, min_width]))
            image = np.random.uniform(size=shape)
            pyr = wavelet.construct(image, num_levels, wavelet_type)

            # Select the coefficient array to edit: the last pyramid entry is
            # a single array, every other entry is a list of bands.
            level = np.int32(np.floor(np.random.uniform() * len(pyr)))
            new_value = np.random.uniform()
            if level == (len(pyr) - 1):
                coeffs = pyr[level]
            else:
                band = np.int32(np.floor(np.random.uniform() * len(pyr[level])))
                coeffs = pyr[level][band]

            # Overwrite one random entry in place; empty arrays are left alone.
            if np.prod(coeffs.shape) > 0:
                c, i, j = np.int32(
                    np.floor(np.array(np.random.uniform(size=3)) *
                             coeffs.shape)).tolist()
                coeffs[c, i, j] = new_value

            # Round-trip the edited decomposition and require it unchanged.
            rebuilt = wavelet.collapse(pyr, wavelet_type)
            pyr_again = wavelet.construct(rebuilt, num_levels, wavelet_type)
            self._assert_pyramids_close(pyr, pyr_again, 1e-8)
def testCollapseMatchesGoldenData(self):
    """Checks that collapse() rebuilds the golden image from its golden pyramid.

    NOTE(review): a `testCollapseMatchesGoldenData(self, device)` definition
    also appears in this chunk. If both live in the same class, this later
    definition replaces the earlier one. Confirm they belong to different
    test classes/files.
    """
    image, golden_pyr, wavelet_type = self._load_golden_data()
    rebuilt = wavelet.collapse(golden_pyr, wavelet_type)
    np.testing.assert_allclose(rebuilt, image, atol=1e-5, rtol=1e-5)
def _draw_toy_band(mu, num_samples):
    """Draws unit-variance normal samples around one band mean `mu`.

    Args:
      mu: A numpy array of shape (3, w, w) holding the ground-truth band mean.
      num_samples: The number of samples to draw.

    Returns:
      A tuple (flat_samples, band_reference). flat_samples has shape
      (num_samples * 3, w, w). band_reference is the (3, w, w) empirical mean
      of the samples over the sample axis.
    """
    w = mu.shape[-1]
    # `loc` broadcasts against `size`, giving the same draws as tiling `mu`.
    band_samples = np.random.normal(loc=mu, size=(num_samples,) + mu.shape)
    band_reference = np.mean(band_samples, 0)
    return np.reshape(band_samples, [-1, w, w]), band_reference


def _generate_wavelet_toy_image_data(image_width, num_samples,
                                     wavelet_num_levels):
    """Generates wavelet data for testFittingImageDataIsCorrect().

    Constructs a "mean" image in the YUV wavelet domain (parametrized by
    `image_width`, and `wavelet_num_levels`) and draws `num_samples` samples
    from a normal distribution using that mean. Returns RGB images
    corresponding to those samples and to the mean (computed in the specified
    latent space) of those samples.

    Args:
      image_width: The width and height in pixels of the images being produced.
      num_samples: The number of samples to generate.
      wavelet_num_levels: The number of levels in the wavelet decompositions of
        the generated images. Must be at least 1.

    Returns:
      A tuple of (samples, reference, color_space, representation), where
        samples = A set of sampled images of size
          (`num_samples`, `image_width`, `image_width`, 3)
        reference = The empirical mean of `samples` (computed in YUV Wavelet
          space but returned as an RGB image) of size
          (`image_width`, `image_width`, 3).
        color_space = 'YUV'
        representation = 'CDF9/7'

    Raises:
      ValueError: If `wavelet_num_levels` < 1. The residual band takes its size
        from the coarsest level, so at least one level is required.
    """
    if wavelet_num_levels < 1:
        raise ValueError('wavelet_num_levels must be >= 1, got {}'.format(
            wavelet_num_levels))
    color_space = 'YUV'
    representation = 'CDF9/7'

    samples = []
    reference = []
    for level in range(wavelet_num_levels):
        w = image_width // 2**(level + 1)
        # Coarser (smaller) levels get proportionally larger band means.
        scaling = 2**level
        level_samples = []
        level_reference = []
        for _ in range(3):
            # Construct the ground-truth pixel band mean, then sample around it.
            mu = scaling * np.random.uniform(size=(3, w, w))
            band_samples, band_reference = _draw_toy_band(mu, num_samples)
            level_samples.append(band_samples)
            level_reference.append(band_reference)
        samples.append(level_samples)
        reference.append(level_reference)

    # Handle the residual band; it shares the coarsest level's size and scaling.
    mu = scaling * np.random.uniform(size=(3, w, w))
    band_samples, band_reference = _draw_toy_band(mu, num_samples)
    samples.append(band_samples)
    reference.append(band_reference)

    # Collapse and reshape wavelets to be ({_,} width, height, 3).
    # collapse() returns torch tensors, so use the tensor's own reshape/permute.
    # np.transpose on a tensor only works via numpy's asarray fallback, which
    # breaks for tensors that require grad.
    samples = wavelet.collapse(samples, representation)
    reference = wavelet.collapse(reference, representation)
    samples = samples.reshape(
        num_samples, 3, image_width, image_width).permute(0, 2, 3, 1)
    reference = reference.permute(1, 2, 0)

    # Convert into RGB space.
    samples = util.syuv_to_rgb(samples)
    reference = util.syuv_to_rgb(reference)

    samples = samples.detach().numpy()
    reference = reference.detach().numpy()
    return samples, reference, color_space, representation