Example 1
    def testAccurateRoundTripWithSmallRandomImages(self):
        """Tests that collapse(construct(x)) == x for x = [1, k, k], k in [0, 4].

        Note: `range(0, 5)` includes width 0, so the degenerate empty-image
        case is exercised in addition to small non-empty images.
        """
        for wavelet_type in wavelet.generate_filters():
            for width in range(0, 5):
                sz = [1, width, width]
                # Use the deepest decomposition this image size supports.
                num_levels = wavelet.get_max_num_levels(sz)
                im = np.random.uniform(size=sz)

                pyr = wavelet.construct(im, num_levels, wavelet_type)
                recon = wavelet.collapse(pyr, wavelet_type)
                # The round trip should be lossless up to float precision.
                np.testing.assert_allclose(recon, im, atol=1e-8, rtol=1e-8)
Example 2
 def _collapse_preserves_dtype(self, float_dtype):
     """Checks that collapse()'s output has the same precision as its input.

     Args:
       float_dtype: A numpy float type (e.g. np.float32) the input pyramid is
         cast to; the collapsed output's dtype is asserted to match it.
     """
     # Build a 3-level pyramid: each level holds three sub-bands of shape
     # (3, n, n), followed by a residual lowpass band. (The original code had
     # a dead `n = 16` assignment here that was immediately shadowed by the
     # loop variable below; it has been removed.)
     x = []
     for n in [8, 4, 2]:
         band = [
             float_dtype(np.random.normal(size=(3, n, n))) for _ in range(3)
         ]
         x.append(band)
     # Residual band: `n` is 2 here, the final value from the loop above.
     x.append(float_dtype(np.random.normal(size=(3, n, n))))
     for wavelet_type in wavelet.generate_filters():
         y = wavelet.collapse(x, wavelet_type)
         np.testing.assert_equal(y.detach().numpy().dtype, float_dtype)
Example 3
 def testAccurateRoundTripWithLargeRandomImages(self):
   """Tests that collapse(construct(x)) == x for large random x's."""
   for wavelet_type in wavelet.generate_filters():
     for _ in range(4):
       # Random decomposition depth in {1, 2, 3, 4}.
       num_levels = np.int32(np.ceil(4 * np.random.uniform()))
       # Smallest spatial extent that can support `num_levels` levels.
       min_dim = 2**(num_levels - 1) + 1
       # Random size up to (2, 32, 32), clamped so the spatial dims are
       # large enough for the chosen depth.
       rand_sz = np.int32(
           np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3)))
       sz = np.maximum(rand_sz, np.array([0, min_dim, min_dim]))
       image = np.random.uniform(size=sz)
       pyramid = wavelet.construct(image, num_levels, wavelet_type)
       reconstruction = wavelet.collapse(pyramid, wavelet_type)
       # The round trip should be lossless up to float precision.
       np.testing.assert_allclose(reconstruction, image, atol=1e-8, rtol=1e-8)
Example 4
  def testCollapseMatchesGoldenData(self, device):
    """Tests collapse() against golden data."""
    im, pyr_true, wavelet_type = self._load_golden_data()

    # Move the golden pyramid onto `device`. Every level except the last
    # holds three sub-bands; the final level is a single residual tensor.
    pyr_true = list(pyr_true)
    last = len(pyr_true) - 1
    for level in range(last):
      bands = list(pyr_true[level])
      for b in range(3):
        bands[b] = torch.tensor(bands[b], device=device)
      pyr_true[level] = bands
    pyr_true[last] = torch.tensor(pyr_true[last], device=device)

    recon = wavelet.collapse(pyr_true, wavelet_type).cpu().detach()
    np.testing.assert_allclose(recon, im, atol=1e-5, rtol=1e-5)
Example 5
    def testDecompositionIsNonRedundant(self):
        """Test that wavelet construction is not redundant.

        If the wavelet decomposition is not redundant, then we should be able to
        1) Construct a wavelet decomposition
        2) Alter a single coefficient in the decomposition
        3) Collapse that decomposition into an image and back
        and the two wavelet decompositions should be the same.
        """
        for wavelet_type in wavelet.generate_filters():
            for _ in range(4):
                # Construct an image and a wavelet decomposition of it.
                num_levels = np.int32(np.ceil(4 * np.random.uniform()))
                sz_clamp = 2**(num_levels - 1) + 1
                sz = np.maximum(
                    np.int32(
                        np.ceil(
                            np.array([2, 32, 32]) *
                            np.random.uniform(size=3))),
                    np.array([0, sz_clamp, sz_clamp]),
                )
                im = np.random.uniform(size=sz)
                pyr = wavelet.construct(im, num_levels, wavelet_type)

                # Pick a coefficient at random in the decomposition to alter.
                # NOTE(review): this alter-and-verify section was previously
                # dedented out of the `for _ in range(4)` loop, so only the
                # last of the four random pyramids was ever checked; it now
                # runs once per pyramid.
                d = np.int32(np.floor(np.random.uniform() * len(pyr)))
                v = np.random.uniform()
                if d == (len(pyr) - 1):
                    # Alter a coefficient in the residual lowpass band.
                    if np.prod(pyr[d].shape) > 0:
                        c, i, j = np.int32(
                            np.floor(
                                np.array(np.random.uniform(size=3)) *
                                pyr[d].shape)).tolist()
                        pyr[d][c, i, j] = v
                else:
                    # Alter a coefficient in a randomly-chosen sub-band.
                    b = np.int32(np.floor(np.random.uniform() * len(pyr[d])))
                    if np.prod(pyr[d][b].shape) > 0:
                        c, i, j = np.int32(
                            np.floor(
                                np.array(np.random.uniform(size=3)) *
                                pyr[d][b].shape)).tolist()
                        pyr[d][b][c, i, j] = v

                # Collapse and then reconstruct the wavelet decomposition,
                # and check that it is unchanged.
                recon = wavelet.collapse(pyr, wavelet_type)
                pyr_again = wavelet.construct(recon, num_levels, wavelet_type)
                self._assert_pyramids_close(pyr, pyr_again, 1e-8)
Example 6
 def testCollapseMatchesGoldenData(self):
     """Verifies collapse() reproduces the stored golden image exactly."""
     golden_im, golden_pyr, wavelet_type = self._load_golden_data()
     reconstruction = wavelet.collapse(golden_pyr, wavelet_type)
     np.testing.assert_allclose(
         reconstruction, golden_im, atol=1e-5, rtol=1e-5)
Example 7
def _generate_wavelet_toy_image_data(image_width, num_samples,
                                     wavelet_num_levels):
  """Generates wavelet data for testFittingImageDataIsCorrect().

  Constructs a "mean" image in the YUV wavelet domain (parametrized by
  `image_width`, and `wavelet_num_levels`) and draws `num_samples` samples
  from a normal distribution using that mean, and returns RGB images
  corresponding to those samples and to the mean (computed in the
  specified latent space) of those samples.

  Args:
    image_width: The width and height in pixels of the images being produced.
    num_samples: The number of samples to generate.
    wavelet_num_levels: The number of levels in the wavelet decompositions of
      the generated images.

  Returns:
    A tuple of (samples, reference, color_space, representation), where
    samples = A set of sampled images of size
      (`num_samples`, `image_width`, `image_width`, 3)
    reference = The empirical mean of `samples` (computed in YUV Wavelet space
      but returned as an RGB image) of size (`image_width`, `image_width`, 3).
    color_space = 'YUV'
    representation = 'CDF9/7'
  """
  color_space = 'YUV'
  representation = 'CDF9/7'
  samples = []
  reference = []
  # Build one pyramid level at a time; level `level` has spatial extent
  # image_width / 2**(level + 1) and its mean is scaled by 2**level.
  for level in range(wavelet_num_levels):
    samples.append([])
    reference.append([])
    w = image_width // 2**(level + 1)
    scaling = 2**level
    # Three sub-bands per level, each with 3 (YUV) channels.
    for _ in range(3):
      # Construct the ground-truth pixel band mean.
      mu = scaling * np.random.uniform(size=(3, w, w))
      # Draw samples from the ground-truth mean (unit variance).
      band_samples = np.random.normal(
          loc=np.tile(mu[np.newaxis], [num_samples, 1, 1, 1]))
      # Take the empirical mean of the samples as a reference.
      band_reference = np.mean(band_samples, 0)
      # Fold the sample axis into the channel axis: (num_samples*3, w, w).
      samples[-1].append(np.reshape(band_samples, [-1, w, w]))
      reference[-1].append(band_reference)
  # Handle the residual band.
  # NOTE(review): `w` and `scaling` intentionally carry over from the last
  # (coarsest) loop iteration, so the residual band has the same resolution
  # and mean scale as the final level — confirm this matches the intended
  # pyramid layout in `wavelet.collapse`.
  mu = scaling * np.random.uniform(size=(3, w, w))
  band_samples = np.random.normal(
      loc=np.tile(mu[np.newaxis], [num_samples, 1, 1, 1]))
  band_reference = np.mean(band_samples, 0)
  samples.append(np.reshape(band_samples, [-1, w, w]))
  reference.append(band_reference)
  # Collapse and reshape wavelets to be ({_,} width, height, 3).
  samples = wavelet.collapse(samples, representation)
  reference = wavelet.collapse(reference, representation)
  # Un-fold the sample axis and move channels last: (N, H, W, 3).
  samples = np.transpose(
      np.reshape(samples, [num_samples, 3, image_width, image_width]),
      [0, 2, 3, 1])
  reference = np.transpose(reference, [1, 2, 0])
  # Convert into RGB space.
  samples = util.syuv_to_rgb(samples)
  reference = util.syuv_to_rgb(reference)
  # syuv_to_rgb presumably returns torch tensors — hence .detach().numpy().
  samples = samples.detach().numpy()
  reference = reference.detach().numpy()
  return samples, reference, color_space, representation