コード例 #1
0
def _generate_wavelet_toy_image_data(image_width, num_samples,
                                     wavelet_num_levels):
    """Generates wavelet data for testFittingImageDataIsCorrect().

  Constructs a "mean" image in the YUV wavelet domain (parametrized by
  `image_width`, and `wavelet_num_levels`) and draws `num_samples` samples
  from a normal distribution using that mean, and returns RGB images
  corresponding to those samples and to the mean (computed in the
  specified latent space) of those samples.

  Args:
    image_width: The width and height in pixels of the images being produced.
    num_samples: The number of samples to generate.
    wavelet_num_levels: The number of levels in the wavelet decompositions of
      the generated images.

  Returns:
    A tuple of (samples, reference, color_space, representation), where
    samples = A set of sampled images of size
      (`num_samples`, `image_width`, `image_width`, 3)
    reference = The empirical mean of `samples` (computed in YUV Wavelet space
      but returned as an RGB image) of size (`image_width`, `image_width`, 3).
    color_space = 'YUV'
    representation = 'CDF9/7'
  """
    color_space = 'YUV'
    representation = 'CDF9/7'
    samples = []
    reference = []
    for level in range(wavelet_num_levels):
        samples.append([])
        reference.append([])
        w = image_width // 2**(level + 1)
        scaling = 2**level
        for _ in range(3):
            # Construct the ground-truth pixel band mean.
            mu = scaling * np.random.uniform(size=(3, w, w))
            # Draw samples from the ground-truth mean.
            band_samples = np.random.normal(
                loc=np.tile(mu[np.newaxis], [num_samples, 1, 1, 1]))
            # Take the empirical mean of the samples as a reference.
            band_reference = np.mean(band_samples, 0)
            samples[-1].append(np.reshape(band_samples, [-1, w, w]))
            reference[-1].append(band_reference)
    # Handle the residual band.
    mu = scaling * np.random.uniform(size=(3, w, w))
    band_samples = np.random.normal(
        loc=np.tile(mu[np.newaxis], [num_samples, 1, 1, 1]))
    band_reference = np.mean(band_samples, 0)
    samples.append(np.reshape(band_samples, [-1, w, w]))
    reference.append(band_reference)
    # Collapse and reshape wavelets to be ({_,} width, height, 3).
    samples = wavelet.collapse(samples, representation)
    reference = wavelet.collapse(reference, representation)
    samples = tf.transpose(
        tf.reshape(samples, [num_samples, 3, image_width, image_width]),
        [0, 2, 3, 1])
    reference = tf.transpose(reference, [1, 2, 0])
    # Convert into RGB space.
    samples = util.syuv_to_rgb(samples)
    reference = util.syuv_to_rgb(reference)
    with tf.Session() as sess:
        samples, reference = sess.run((samples, reference))
    return samples, reference, color_space, representation
コード例 #2
0
 def testRgbToSyuvRoundTrip(self):
     """Checks that converting RGB -> sYUV -> RGB reproduces the input."""
     # A random 32x32x3 float32 image with values drawn uniformly from [0, 1).
     original = np.random.uniform(size=(32, 32, 3)).astype(np.float32)
     reconstructed = util.syuv_to_rgb(util.rgb_to_syuv(original))
     self.assertAllClose(original, reconstructed)