Пример #1
0
 def testCollapseMatchesGoldenData(self):
     """Checks that collapse() rebuilds the golden image from its golden pyramid.

     The golden image, its pyramid and the wavelet type all come from
     self._load_golden_data(); the collapsed pyramid is evaluated in a session
     and compared to the image with atol = rtol = 1e-5.
     """
     golden_image, golden_pyramid, wavelet_type = self._load_golden_data()
     recon_op = wavelet.collapse(golden_pyramid, wavelet_type)
     with self.session() as sess:
         recon_value = sess.run(recon_op)
     self.assertAllClose(recon_value, golden_image, atol=1e-5, rtol=1e-5)
Пример #2
0
  def testAccurateRoundTripWithSmallRandomImages(self):
    """Tests that collapse(construct(x)) == x for x = [1, k, k], k in [1, 4].

    For every filter from wavelet.generate_filters() and every width k in
    1..4, a uniform random image of shape [1, k, k] is decomposed using the
    maximum number of levels wavelet.get_max_num_levels() allows for that
    shape, reconstructed, and compared to the input with atol = rtol = 1e-8.
    """
    widths = range(1, 5)
    for wavelet_type in wavelet.generate_filters():
      for k in widths:
        shape = [1, k, k]
        levels = wavelet.get_max_num_levels(shape)
        image = np.random.uniform(size=shape)
        pyramid = wavelet.construct(image, levels, wavelet_type)
        self.assertAllClose(
            wavelet.collapse(pyramid, wavelet_type), image,
            atol=1e-8, rtol=1e-8)
Пример #3
0
 def testCollapsePreservesDtype(self, float_dtype):
   """Checks that collapse()'s output has the same precision as its input.

   Builds a pyramid whose detail levels each hold three (3, n, n) bands
   (n = 8, 4, 2), plus a residual band the size of the coarsest level. Every
   band is cast with `float_dtype`. The test then asserts that collapsing
   the pyramid with each filter from wavelet.generate_filters() yields
   `float_dtype` output.

   Args:
     float_dtype: Callable float type used to cast every band (presumably
       np.float32 / np.float64 supplied by a parameterization decorator not
       visible here).
   """
   level_widths = [8, 4, 2]
   x = []
   for n in level_widths:
     band = []
     for _ in range(3):
       band.append(float_dtype(np.random.normal(size=(3, n, n))))
     x.append(band)
   # The residual band has the same spatial size as the coarsest level; name
   # it explicitly instead of relying on the loop variable leaking out.
   residual_width = level_widths[-1]
   x.append(
       float_dtype(np.random.normal(size=(3, residual_width, residual_width))))
   for wavelet_type in wavelet.generate_filters():
     y = wavelet.collapse(x, wavelet_type)
     self.assertDTypeEqual(y, float_dtype)
Пример #4
0
    def testDecompositionIsNonRedundant(self):
        """Test that wavelet construction is not redundant.

        If the wavelet decomposition is not redundant, then we should be able
        to
        1) Construct a wavelet decomposition
        2) Alter a single coefficient in the decomposition
        3) Collapse that decomposition into an image and back
        and the two wavelet decompositions should be the same.

        This is checked for four random images per filter returned by
        wavelet.generate_filters().
        """
        for wavelet_type in wavelet.generate_filters():
            for _ in range(4):
                # Construct an image and a wavelet decomposition of it.
                num_levels = np.int32(np.ceil(4 * np.random.uniform()))
                # Clamp the spatial dims to at least 2**(num_levels - 1) + 1
                # (presumably the smallest size that supports `num_levels`
                # decomposition levels).
                sz_clamp = 2**(num_levels - 1) + 1
                sz = np.maximum(
                    np.int32(
                        np.ceil(
                            np.array([2, 32, 32]) *
                            np.random.uniform(size=3))),
                    np.array([0, sz_clamp, sz_clamp]))
                im = np.random.uniform(size=sz)
                pyr = wavelet.construct(im, num_levels, wavelet_type)
                pyr = list(pyr)

                # Pick a coefficient at random in the decomposition to alter.
                # This must happen for every constructed pyramid, i.e. inside
                # the inner loop; otherwise only the last image is tested.
                d = np.int32(np.floor(np.random.uniform() * len(pyr)))
                v = np.random.uniform()
                if d == (len(pyr) - 1):
                    # The last pyramid entry is indexed as a single 3-D array.
                    if np.prod(pyr[d].shape) > 0:
                        c, i, j = np.int32(
                            np.floor(
                                np.array(np.random.uniform(size=3)) *
                                pyr[d].shape)).tolist()
                        pyr[d] = pyr[d].numpy()
                        pyr[d][c, i, j] = v
                else:
                    # Earlier entries are sequences of 3-D bands; pick a band.
                    b = np.int32(np.floor(np.random.uniform() * len(pyr[d])))
                    if np.prod(pyr[d][b].shape) > 0:
                        c, i, j = np.int32(
                            np.floor(
                                np.array(np.random.uniform(size=3)) *
                                pyr[d][b].shape)).tolist()
                        pyr[d] = list(pyr[d])
                        pyr[d][b] = pyr[d][b].numpy()
                        pyr[d][b][c, i, j] = v

                # Collapse and then reconstruct the wavelet decomposition, and
                # check that it is unchanged.
                recon = wavelet.collapse(pyr, wavelet_type)
                pyr_again = wavelet.construct(recon, num_levels, wavelet_type)
                self._assert_pyramids_close(pyr, pyr_again, 1e-8)
Пример #5
0
 def testAccurateRoundTripWithLargeRandomImages(self):
   """Tests that collapse(construct(x)) == x for large random x's.

   For each filter from wavelet.generate_filters(), four images are drawn.
   Each has a random level count (ceil of 4 * uniform) and a random shape of
   up to [2, 32, 32], with the spatial dims clamped to at least
   2**(levels - 1) + 1. Each image is decomposed and reconstructed, and the
   result must match the input with atol = rtol = 1e-8.
   """
   max_shape = np.array([2, 32, 32])
   for wavelet_type in wavelet.generate_filters():
     for _ in range(4):
       levels = np.int32(np.ceil(4 * np.random.uniform()))
       min_side = 2**(levels - 1) + 1
       random_shape = np.int32(np.ceil(max_shape * np.random.uniform(size=3)))
       shape = np.maximum(random_shape, np.array([0, min_side, min_side]))
       image = np.random.uniform(size=shape)
       pyramid = wavelet.construct(image, levels, wavelet_type)
       reconstruction = wavelet.collapse(pyramid, wavelet_type)
       self.assertAllClose(reconstruction, image, atol=1e-8, rtol=1e-8)
Пример #6
0
def _sample_wavelet_band(scaling, w, num_samples):
    """Draws `num_samples` noisy copies of a random (3, w, w) band mean.

    Args:
      scaling: Multiplier applied to the uniform random band mean.
      w: Spatial width and height of the band.
      num_samples: The number of samples to draw.

    Returns:
      A tuple (samples, reference), where `samples` is the draws reshaped to
      (num_samples * 3, w, w) and `reference` is their empirical mean over
      the sample axis, of shape (3, w, w).
    """
    # Construct the ground-truth band mean.
    mu = scaling * np.random.uniform(size=(3, w, w))
    # Draw samples (unit scale, numpy's default) around the ground-truth mean.
    band_samples = np.random.normal(
        loc=np.tile(mu[np.newaxis], [num_samples, 1, 1, 1]))
    # Take the empirical mean of the samples as a reference.
    band_reference = np.mean(band_samples, 0)
    return np.reshape(band_samples, [-1, w, w]), band_reference


def _generate_wavelet_toy_image_data(image_width, num_samples,
                                     wavelet_num_levels):
    """Generates wavelet data for testFittingImageDataIsCorrect().

    Constructs a "mean" image in the YUV wavelet domain (parametrized by
    `image_width`, and `wavelet_num_levels`) and draws `num_samples` samples
    from a normal distribution using that mean, and returns RGB images
    corresponding to those samples and to the mean (computed in the
    specified latent space) of those samples.

    Args:
      image_width: The width and height in pixels of the images being
        produced.
      num_samples: The number of samples to generate.
      wavelet_num_levels: The number of levels in the wavelet decompositions
        of the generated images. Must be at least 1.

    Returns:
      A tuple of (samples, reference, color_space, representation), where
      samples = A set of sampled images of size
        (`num_samples`, `image_width`, `image_width`, 3)
      reference = The empirical mean of `samples` (computed in YUV Wavelet
        space but returned as an RGB image) of size
        (`image_width`, `image_width`, 3).
      color_space = 'YUV'
      representation = 'CDF9/7'

    Raises:
      ValueError: If `wavelet_num_levels` < 1. The residual band is sized
        from the coarsest decomposition level, so at least one level is
        required.
    """
    if wavelet_num_levels < 1:
        raise ValueError('wavelet_num_levels must be >= 1, got {}'.format(
            wavelet_num_levels))
    color_space = 'YUV'
    representation = 'CDF9/7'
    samples = []
    reference = []
    for level in range(wavelet_num_levels):
        w = image_width // 2**(level + 1)
        scaling = 2**level
        level_samples = []
        level_reference = []
        for _ in range(3):
            band_samples, band_reference = _sample_wavelet_band(
                scaling, w, num_samples)
            level_samples.append(band_samples)
            level_reference.append(band_reference)
        samples.append(level_samples)
        reference.append(level_reference)
    # Handle the residual band: it reuses the coarsest level's width and
    # scaling (the values left over from the last loop iteration).
    residual_samples, residual_reference = _sample_wavelet_band(
        scaling, w, num_samples)
    samples.append(residual_samples)
    reference.append(residual_reference)
    # Collapse and reshape wavelets to be ({_,} width, height, 3).
    samples = wavelet.collapse(samples, representation)
    reference = wavelet.collapse(reference, representation)
    samples = tf.transpose(
        tf.reshape(samples, [num_samples, 3, image_width, image_width]),
        [0, 2, 3, 1])
    reference = tf.transpose(reference, [1, 2, 0])
    # Convert into RGB space.
    samples = util.syuv_to_rgb(samples)
    reference = util.syuv_to_rgb(reference)
    with tf.Session() as sess:
        samples, reference = sess.run((samples, reference))
    return samples, reference, color_space, representation