def _sample_toy_band(scaling, width, num_samples):
  """Draws `num_samples` noisy copies of one random 3-channel wavelet band.

  A ground-truth band mean `mu` of shape (3, `width`, `width`) is drawn
  uniformly from [0, `scaling`). Each sample is then drawn from a
  unit-variance normal distribution centred on `mu`. Both draws use the
  global NumPy RNG, in that order.

  Args:
    scaling: Multiplier applied to the uniform draw that forms the band mean.
    width: The width and height in pixels of the band.
    num_samples: The number of samples to draw.

  Returns:
    A tuple (band_samples, band_reference), where:
      band_samples has shape (`num_samples` * 3, `width`, `width`), with the
        sample and channel axes flattened together.
      band_reference is the empirical mean over samples, with shape
        (3, `width`, `width`).
  """
  # Construct the ground-truth band mean.
  mu = scaling * np.random.uniform(size=(3, width, width))
  # Draw samples from a unit-variance normal centred on the ground-truth mean.
  band_samples = np.random.normal(
      loc=np.tile(mu[np.newaxis], [num_samples, 1, 1, 1]))
  # Take the empirical mean of the samples as a reference.
  band_reference = np.mean(band_samples, 0)
  return np.reshape(band_samples, [-1, width, width]), band_reference


def _generate_wavelet_toy_image_data(image_width, num_samples,
                                     wavelet_num_levels):
  """Generates wavelet data for testFittingImageDataIsCorrect().

  Constructs a "mean" image in the YUV wavelet domain (parametrized by
  `image_width` and `wavelet_num_levels`). It then draws `num_samples`
  samples from a unit-variance normal distribution around that mean. It
  returns RGB images corresponding to those samples and to their empirical
  mean (computed in the wavelet domain). All randomness comes from the global
  NumPy RNG.

  Args:
    image_width: The width and height in pixels of the images being produced.
    num_samples: The number of samples to generate.
    wavelet_num_levels: The number of levels in the wavelet decompositions of
      the generated images. Must be at least 1.

  Returns:
    A tuple of (samples, reference, color_space, representation), where:
      samples = A set of sampled images of size
        (`num_samples`, `image_width`, `image_width`, 3).
      reference = The empirical mean of `samples` (computed in YUV wavelet
        space but returned as an RGB image) of size
        (`image_width`, `image_width`, 3).
      color_space = 'YUV'
      representation = 'CDF9/7'

  Raises:
    ValueError: If `wavelet_num_levels` is less than 1.
  """
  if wavelet_num_levels < 1:
    # The residual band's scale is defined relative to the coarsest detail
    # level (2**(levels - 1)), so at least one level is required.
    raise ValueError('wavelet_num_levels must be at least 1, got {}'.format(
        wavelet_num_levels))
  color_space = 'YUV'
  representation = 'CDF9/7'
  samples = []
  reference = []
  for level in range(wavelet_num_levels):
    width = image_width // 2**(level + 1)
    scaling = 2**level
    level_samples = []
    level_reference = []
    # Three bands per level; presumably the detail subbands that
    # wavelet.collapse expects -- confirm against wavelet.py.
    for _ in range(3):
      band_samples, band_reference = _sample_toy_band(
          scaling, width, num_samples)
      level_samples.append(band_samples)
      level_reference.append(band_reference)
    samples.append(level_samples)
    reference.append(level_reference)
  # Handle the residual band. It shares the size and scaling of the coarsest
  # level above. These are computed explicitly instead of relying on
  # variables leaked from the loop.
  residual_width = image_width // 2**wavelet_num_levels
  residual_scaling = 2**(wavelet_num_levels - 1)
  band_samples, band_reference = _sample_toy_band(
      residual_scaling, residual_width, num_samples)
  samples.append(band_samples)
  reference.append(band_reference)
  # Collapse and reshape wavelets to be ({_,} width, height, 3): the flattened
  # sample/channel axis is split back into (num_samples, 3) and channels are
  # moved last.
  samples = wavelet.collapse(samples, representation)
  reference = wavelet.collapse(reference, representation)
  samples = tf.transpose(
      tf.reshape(samples, [num_samples, 3, image_width, image_width]),
      [0, 2, 3, 1])
  reference = tf.transpose(reference, [1, 2, 0])
  # Convert into RGB space.
  samples = util.syuv_to_rgb(samples)
  reference = util.syuv_to_rgb(reference)
  with tf.Session() as sess:
    samples, reference = sess.run((samples, reference))
  return samples, reference, color_space, representation
def testRgbToSyuvRoundTrip(self):
  """Checks that converting RGB -> sYUV -> RGB recovers the input image.

  A random 32x32 float32 RGB image is pushed through util.rgb_to_syuv and
  back through util.syuv_to_rgb; the result must match the input to within
  assertAllClose tolerances.
  """
  image_rgb = np.random.uniform(size=(32, 32, 3)).astype(np.float32)
  image_syuv = util.rgb_to_syuv(image_rgb)
  self.assertAllClose(image_rgb, util.syuv_to_rgb(image_syuv))