예제 #1
0
  def testDecompositionIsNonRedundant(self):
    """Tests that the wavelet decomposition is not redundant.

    If the wavelet decomposition is not redundant, then we should be able to
    1) Construct a wavelet decomposition
    2) Alter a single coefficient in the decomposition
    3) Collapse that decomposition into an image and back
    and the two wavelet decompositions should be the same.

    Every one of the 4 random trials per wavelet type is checked end to end.
    """
    for wavelet_type in wavelet.generate_filters():
      for _ in range(4):
        # Construct an image and a wavelet decomposition of it. The two
        # trailing image dimensions are bounded below by a value that grows
        # with the number of levels.
        num_levels = np.int32(np.ceil(4 * np.random.uniform()))
        sz_clamp = 2**(num_levels - 1) + 1
        sz = np.maximum(
            np.int32(
                np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3))),
            np.array([0, sz_clamp, sz_clamp]))
        im = np.random.uniform(size=sz)
        # Copy into a list so that individual entries can be replaced below.
        pyr = list(wavelet.construct(im, num_levels, wavelet_type))

        # Pick a coefficient at random in the decomposition to alter.
        d = np.int32(np.floor(np.random.uniform() * len(pyr)))
        v = np.random.uniform()
        if d == (len(pyr) - 1):
          # The final entry is indexed directly as a single 3-D array.
          if np.prod(pyr[d].shape) > 0:
            c, i, j = np.int32(
                np.floor(np.array(np.random.uniform(size=3)) *
                         pyr[d].shape)).tolist()
            # Convert via .numpy() so one element can be assigned in place.
            pyr[d] = pyr[d].numpy()
            pyr[d][c, i, j] = v
        else:
          # Every other entry is a sequence of arrays; pick one of them first.
          b = np.int32(np.floor(np.random.uniform() * len(pyr[d])))
          if np.prod(pyr[d][b].shape) > 0:
            c, i, j = np.int32(
                np.floor(np.array(np.random.uniform(size=3)) *
                         pyr[d][b].shape)).tolist()
            pyr[d] = list(pyr[d])
            pyr[d][b] = pyr[d][b].numpy()
            pyr[d][b][c, i, j] = v

        # Collapse and then reconstruct the wavelet decomposition, and check
        # that it is unchanged.
        recon = wavelet.collapse(pyr, wavelet_type)
        pyr_again = wavelet.construct(recon, num_levels, wavelet_type)
        self._assert_pyramids_close(pyr, pyr_again, 1e-8)
예제 #2
0
 def testRescaleAndUnrescaleReproducesInput(self):
   """Checks that rescaling by k and then by 1/k recovers the pyramid.

   A 'LeGall5/3' pyramid of a random image is rescaled by a random factor and
   then by that factor's reciprocal; the result is compared against the
   original pyramid.
   """
   image = np.random.uniform(size=(2, 32, 32))
   # exp() of a normal sample is always strictly positive.
   factor = np.exp(np.random.normal())
   original = wavelet.construct(image, 4, 'LeGall5/3')
   round_tripped = wavelet.rescale(
       wavelet.rescale(original, factor), 1. / factor)
   self._assert_pyramids_close(original, round_tripped, 1e-8)
예제 #3
0
  def testAccurateRoundTripWithSmallRandomImages(self):
    """Checks collapse(construct(x)) == x for shapes [1, k, k], k = 1..4.

    Each filter from wavelet.generate_filters() is exercised, using the
    number of levels reported by wavelet.get_max_num_levels() for the shape.
    """
    for wavelet_type in wavelet.generate_filters():
      for side in range(1, 5):
        shape = [1, side, side]
        levels = wavelet.get_max_num_levels(shape)
        image = np.random.uniform(size=shape)

        # Decompose and immediately collapse; the image must be recovered.
        round_tripped = wavelet.collapse(
            wavelet.construct(image, levels, wavelet_type), wavelet_type)
        self.assertAllClose(round_tripped, image, atol=1e-8, rtol=1e-8)
예제 #4
0
 def testRescaleOneHalfIsNormalized(self):
   """Checks rescale(construct(k), 0.5)[-1] == k for a constant image k.

   Runs once for each number of levels in 0..4 with a fresh random constant.
   """
   for num_levels in range(5):
     constant = np.random.uniform()
     image = constant * np.ones((2, 32, 32))
     rescaled = wavelet.rescale(
         wavelet.construct(image, num_levels, 'LeGall5/3'), 0.5)
     # The last pyramid entry should be filled with the constant.
     last = rescaled[-1]
     expected = constant * np.ones_like(last)
     self.assertAllClose(last, expected, atol=1e-8, rtol=1e-8)
예제 #5
0
 def testAccurateRoundTripWithLargeRandomImages(self):
   """Checks that collapse(construct(x)) recovers x for larger random images.

   For every filter, 4 trials are run with a random number of levels and a
   random image shape of at most [2, 32, 32].
   """
   for wavelet_type in wavelet.generate_filters():
     for _ in range(4):
       num_levels = np.int32(np.ceil(4 * np.random.uniform()))
       # Lower bound on the two trailing dimensions; grows with num_levels.
       min_side = 2**(num_levels - 1) + 1
       max_shape = np.array([2, 32, 32])
       random_shape = np.int32(
           np.ceil(max_shape * np.random.uniform(size=3)))
       shape = np.maximum(random_shape, np.array([0, min_side, min_side]))
       image = np.random.uniform(size=shape)

       # Decompose and immediately collapse; the image must be recovered.
       pyramid = wavelet.construct(image, num_levels, wavelet_type)
       round_tripped = wavelet.collapse(pyramid, wavelet_type)
       self.assertAllClose(round_tripped, image, atol=1e-8, rtol=1e-8)
예제 #6
0
 def testRescaleDoesNotAffectTheFirstLevel(self):
   """Checks that rescale(x, s)[0] equals x[0] for a random scale s."""
   image = np.random.uniform(size=(2, 32, 32))
   original = wavelet.construct(image, 4, 'LeGall5/3')
   # exp() of a normal sample gives a random, strictly positive scale.
   scale = np.exp(np.random.normal())
   rescaled = wavelet.rescale(original, scale)
   # Only the first entry of each pyramid is compared.
   self._assert_pyramids_close(original[0:1], rescaled[0:1], 1e-8)
예제 #7
0
 def testRescaleOneIsANoOp(self):
   """Checks that rescaling a pyramid by 1 leaves it unchanged."""
   image = np.random.uniform(size=(2, 32, 32))
   original = wavelet.construct(image, 4, 'LeGall5/3')
   self._assert_pyramids_close(
       original, wavelet.rescale(original, 1.), 1e-8)
예제 #8
0
 def testConstructPreservesDtype(self, float_dtype):
   """Checks that construct() output has the same dtype as its input.

   Args:
     float_dtype: Callable used to cast the random input, and the dtype the
       flattened decomposition is expected to have.
   """
   samples = np.random.normal(size=(3, 16, 16))
   x = float_dtype(samples)
   for wavelet_type in wavelet.generate_filters():
     pyramid = wavelet.construct(x, 3, wavelet_type)
     flattened = wavelet.flatten(pyramid)
     self.assertDTypeEqual(flattened, float_dtype)
예제 #9
0
 def testConstructMatchesGoldenData(self):
   """Checks construct() against the stored golden decomposition."""
   image, expected_pyramid, wavelet_type = self._load_golden_data()
   # The level count passed to construct() is one less than the number of
   # entries in the golden pyramid.
   num_levels = len(expected_pyramid) - 1
   actual_pyramid = wavelet.construct(image, num_levels, wavelet_type)
   self._assert_pyramids_close(actual_pyramid, expected_pyramid, 1e-5)
예제 #10
0
 def fun(z):
   """Returns the flattened wavelet decomposition of `z`.

   `num_levels` and `wavelet_type` are free variables resolved from the
   enclosing scope, hence the pylint suppression below.
   """
   # pylint: disable=cell-var-from-loop
   pyramid = wavelet.construct(z, num_levels, wavelet_type)
   return wavelet.flatten(pyramid)