Example #1
0
    def testDecompositionIsNonRedundant(self):
        """Test that wavelet construction is not redundant.

        If the wavelet decomposition is not redundant, then we should be able
        to
          1) construct a wavelet decomposition,
          2) alter a single coefficient in the decomposition, and
          3) collapse that decomposition into an image and back,
        and the two wavelet decompositions should be the same.

        The alteration and round-trip check run once per random trial, so
        every randomly sized decomposition is exercised, not only the last one
        built for each filter.
        """
        for wavelet_type in wavelet.generate_filters():
            for _ in range(4):
                # Construct an image and a wavelet decomposition of it. The
                # spatial dims are clamped below at 2**(num_levels - 1) + 1,
                # presumably the smallest size `construct` accepts for that
                # many levels.
                num_levels = np.int32(np.ceil(4 * np.random.uniform()))
                sz_clamp = 2**(num_levels - 1) + 1
                sz = np.maximum(
                    np.int32(
                        np.ceil(
                            np.array([2, 32, 32]) *
                            np.random.uniform(size=3))),
                    np.array([0, sz_clamp, sz_clamp]),
                )
                im = np.random.uniform(size=sz)
                pyr = wavelet.construct(im, num_levels, wavelet_type)

                # Pick a coefficient at random in the decomposition to alter.
                # The last entry of `pyr` is indexed as a single 3-D array;
                # every other entry is a sequence of 3-D band arrays, so a
                # band is picked first.
                d = np.int32(np.floor(np.random.uniform() * len(pyr)))
                v = np.random.uniform()
                if d == (len(pyr) - 1):
                    target = pyr[d]
                else:
                    b = np.int32(np.floor(np.random.uniform() * len(pyr[d])))
                    target = pyr[d][b]
                # Empty arrays have no coefficient that could be altered.
                if np.prod(target.shape) > 0:
                    c, i, j = np.int32(
                        np.floor(
                            np.random.uniform(size=3) * target.shape)).tolist()
                    # In-place write, so `pyr` itself is modified.
                    target[c, i, j] = v

                # Collapse and then reconstruct the wavelet decomposition, and
                # check that it is unchanged.
                recon = wavelet.collapse(pyr, wavelet_type)
                pyr_again = wavelet.construct(recon, num_levels, wavelet_type)
                self._assert_pyramids_close(pyr, pyr_again, 1e-8)
Example #2
0
    def transform_to_mat(self, x):
        """Transforms a batch of images to a matrix.

        Args:
          x: A rank-4 batch of images laid out as
            (num_batches, width, height, num_channels); it must expose
            `.shape` and be convertible with `torch.as_tensor`.

        Returns:
          A tensor of shape (num_batches, width * height * num_channels),
          built after the optional color conversion selected by
          `self.color_space` and the per-channel spatial transform selected by
          `self.representation`. Assumes that transform keeps each channel at
          (width, height) — TODO confirm for the wavelet path.
        """
        assert len(x.shape) == 4
        images = torch.as_tensor(x)
        # Only 'YUV' triggers a conversion; 'RGB' leaves pixels untouched.
        if self.color_space == 'YUV':
            images = util.rgb_to_syuv(images)

        _, width, height, num_channels = images.shape

        # (num_batches, width, height, num_channels)
        #   -> (num_batches, num_channels, width, height)
        #   -> (num_batches * num_channels, width, height)
        channels_first = images.permute(0, 3, 1, 2)
        per_channel = torch.reshape(channels_first, (-1, width, height))

        # Map every single-channel image into the requested representation.
        # Any other value (e.g. 'PIXEL') leaves the pixels as they are.
        if self.representation in wavelet.generate_filters():
            pyramid = wavelet.construct(per_channel, self.wavelet_num_levels,
                                        self.representation)
            rescaled = wavelet.rescale(pyramid, self.wavelet_scale_base)
            per_channel = wavelet.flatten(rescaled)
        elif self.representation == 'DCT':
            per_channel = util.image_dct(per_channel)

        # (num_batches * num_channels, width, height)
        #   -> (num_batches, num_channels, width, height)
        #   -> (num_batches, width, height, num_channels)
        #   -> (num_batches, width * height * num_channels)
        regrouped = torch.reshape(per_channel,
                                  (-1, num_channels, width, height))
        channels_last = regrouped.permute(0, 2, 3, 1)
        return torch.reshape(channels_last,
                             [-1, width * height * num_channels])
Example #3
0
 def testRescaleAndUnrescaleReproducesInput(self):
     """Tests that rescale(rescale(x, k), 1/k) = x for a random positive k."""
     image = np.random.uniform(size=(2, 32, 32))
     # exp of a standard normal draw: always positive, spread around 1.
     k = np.exp(np.random.normal())
     original = wavelet.construct(image, 4, 'LeGall5/3')
     scaled = wavelet.rescale(original, k)
     round_trip = wavelet.rescale(scaled, 1. / k)
     self._assert_pyramids_close(original, round_trip, 1e-8)
Example #4
0
 def testRescaleOneHalfIsNormalized(self):
     """Tests that rescale(construct(k), 0.5)[-1] = k for constant image k.

     Checked for every level count from 0 through 4, each with a fresh
     random constant.
     """
     for levels in range(5):
         value = np.random.uniform()
         constant_image = np.full((2, 32, 32), value)
         pyramid = wavelet.construct(constant_image, levels, 'LeGall5/3')
         last = wavelet.rescale(pyramid, 0.5)[-1]
         expected = value * np.ones_like(last)
         np.testing.assert_allclose(last, expected, atol=1e-8, rtol=1e-8)
Example #5
0
    def testAccurateRoundTripWithSmallRandomImages(self):
        """Tests that collapse(construct(x)) == x for x = [1, k, k], k in [0, 4].

        k = 0 (an empty image) is included. Each size is decomposed with the
        level count returned by `wavelet.get_max_num_levels`.
        """
        for wavelet_type in wavelet.generate_filters():
            for width in range(5):
                sz = [1, width, width]
                num_levels = wavelet.get_max_num_levels(sz)
                im = np.random.uniform(size=sz)

                pyr = wavelet.construct(im, num_levels, wavelet_type)
                recon = wavelet.collapse(pyr, wavelet_type)
                np.testing.assert_allclose(recon, im, atol=1e-8, rtol=1e-8)
Example #6
0
 def testAccurateRoundTripWithLargeRandomImages(self):
   """Tests that collapse(construct(x)) == x for large random x's.

   Each trial draws a level count in 1..4 and a random size up to
   (2, 32, 32). The spatial dims are clamped below at
   2**(num_levels - 1) + 1.
   """
   trials_per_filter = 4
   for wavelet_type in wavelet.generate_filters():
     for _ in range(trials_per_filter):
       num_levels = np.int32(np.ceil(4 * np.random.uniform()))
       min_side = 2**(num_levels - 1) + 1
       drawn = np.int32(
           np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3)))
       sz = np.maximum(drawn, np.array([0, min_side, min_side]))
       im = np.random.uniform(size=sz)
       pyramid = wavelet.construct(im, num_levels, wavelet_type)
       np.testing.assert_allclose(
           wavelet.collapse(pyramid, wavelet_type), im, atol=1e-8, rtol=1e-8)
Example #7
0
  def testConstructMatchesGoldenData(self, device):
    """Tests construct() against golden data on the given device.

    The golden input image is moved to `device` and decomposed. Every
    coefficient tensor of the result is then copied back to the CPU and
    detached before it is compared to the golden pyramid.
    """
    im, pyr_true, wavelet_type = self._load_golden_data()
    im = torch.tensor(im, device=device)
    pyr = list(wavelet.construct(im, len(pyr_true) - 1, wavelet_type))

    # Every entry but the last is a sequence of band tensors; the last is a
    # single tensor. Iterate over the bands actually present instead of
    # hard-coding a band count.
    host_pyr = [[band.cpu().detach() for band in level] for level in pyr[:-1]]
    host_pyr.append(pyr[-1].cpu().detach())

    self._assert_pyramids_close(host_pyr, pyr_true, 1e-5)
Example #8
0
 def testWaveletTransformationIsVolumePreserving(self):
   """Tests that construct() is volume preserving when size is a power of 2.

   For a (1, 4, 4) float32 image and 2 levels, the Jacobian of the flattened
   decomposition with respect to the input is assembled one coefficient at a
   time via autograd. Its determinant must be close to 1.
   """
   sz = (1, 4, 4)
   num_levels = 2
   for wavelet_type in wavelet.generate_filters():
     # Construct the Jacobian of construct().
     im = np.float32(np.random.uniform(0., 1., sz))
     jacobian = []
     for d in range(im.size):
       # `torch.autograd.Variable` is deprecated; a leaf tensor created with
       # requires_grad=True is the direct replacement. A fresh leaf for each
       # coefficient keeps gradients from accumulating across iterations.
       var_im = torch.tensor(im, requires_grad=True)
       coeffs = torch.reshape(
           wavelet.flatten(
               wavelet.construct(var_im, num_levels, wavelet_type)), [-1])
       coeffs[d].backward()
       jacobian.append(np.reshape(var_im.grad.detach().numpy(), [-1]))
     jacobian = np.stack(jacobian, 1)
     # Assert that the determinant of the Jacobian is close to 1.
     det = np.linalg.det(jacobian)
     np.testing.assert_allclose(det, 1., atol=1e-5, rtol=1e-5)
Example #9
0
 def testRescaleDoesNotAffectTheFirstLevel(self):
     """Tests that rescale(x, s)[0] = x[0] for any s.

     Uses one random positive scale s = exp(N(0, 1)).
     """
     im = np.random.uniform(size=(2, 32, 32))
     pyr = wavelet.construct(im, 4, 'LeGall5/3')
     scale = np.exp(np.random.normal())
     first_before = pyr[:1]
     first_after = wavelet.rescale(pyr, scale)[:1]
     self._assert_pyramids_close(first_before, first_after, 1e-8)
Example #10
0
 def testRescaleOneIsANoOp(self):
     """Tests that rescale(x, 1) = x, i.e. a unit scale changes nothing."""
     pyr = wavelet.construct(
         np.random.uniform(size=(2, 32, 32)), 4, 'LeGall5/3')
     self._assert_pyramids_close(pyr, wavelet.rescale(pyr, 1.), 1e-8)
Example #11
0
 def _construct_preserves_dtype(self, float_dtype):
     """Checks that construct()'s output has the same precision as its input.

     Args:
       float_dtype: Callable used to cast the random input and also the
         expected output dtype (presumably a NumPy scalar type such as
         np.float32 — confirm against callers).
     """
     x = float_dtype(np.random.normal(size=(3, 16, 16)))
     for wavelet_type in wavelet.generate_filters():
         flat = wavelet.flatten(wavelet.construct(x, 3, wavelet_type))
         out_dtype = flat.detach().numpy().dtype
         np.testing.assert_equal(out_dtype, float_dtype)
Example #12
0
 def testConstructMatchesGoldenData(self):
     """Tests construct() against golden data.

     The number of levels requested is one less than the golden pyramid's
     length.
     """
     im, pyr_true, wavelet_type = self._load_golden_data()
     num_levels = len(pyr_true) - 1
     self._assert_pyramids_close(
         wavelet.construct(im, num_levels, wavelet_type), pyr_true, 1e-5)