示例#1
0
 def testRescaleAndUnrescaleReproducesInput(self):
     """Checks that rescaling by k and then by 1/k gives back the pyramid."""
     image = np.random.uniform(size=(2, 32, 32))
     # A random positive scale base, drawn after the image just as before.
     k = np.exp(np.random.normal())
     original = wavelet.construct(image, 4, 'LeGall5/3')
     round_trip = wavelet.rescale(wavelet.rescale(original, k), 1. / k)
     self._assert_pyramids_close(original, round_trip, 1e-8)
示例#2
0
    def transform_to_mat(self, x):
        """Transforms a batch of images to a matrix.

        Args:
          x: A rank-4 batch of images laid out as
            (num_batches, width, height, num_channels). Any input accepted by
            `torch.as_tensor` is allowed.

        Returns:
          A tensor of shape (num_batches, width * height * num_channels), where
          each row holds one image after the color-space conversion and the
          spatial representation selected by `self.color_space` and
          `self.representation` have been applied.
        """
        # Convert before checking the rank, so that array-likes without a
        # `.shape` attribute (e.g. nested lists) are validated instead of
        # failing with an AttributeError.
        x = torch.as_tensor(x)
        assert len(x.shape) == 4
        if self.color_space == 'YUV':
            x = util.rgb_to_syuv(x)
        # If `color_space` == 'RGB', do nothing.

        # Reshape `x` from
        #   (num_batches, width, height, num_channels) to
        #   (num_batches * num_channels, width, height)
        _, width, height, num_channels = x.shape
        x_stack = torch.reshape(x.permute(0, 3, 1, 2), (-1, width, height))

        # Turn each channel in `x_stack` into the spatial representation specified
        # by `representation`.
        if self.representation in wavelet.generate_filters():
            x_stack = wavelet.flatten(
                wavelet.rescale(
                    wavelet.construct(x_stack, self.wavelet_num_levels,
                                      self.representation),
                    self.wavelet_scale_base))
        elif self.representation == 'DCT':
            x_stack = util.image_dct(x_stack)
        # If `representation` == 'PIXEL', do nothing.

        # Reshape `x_stack` from
        #   (num_batches * num_channels, width, height) to
        #   (num_batches, num_channels * width * height)
        x_mat = torch.reshape(
            torch.reshape(x_stack, (-1, num_channels, width, height)).permute(
                0, 2, 3, 1), [-1, width * height * num_channels])
        return x_mat
示例#3
0
 def testRescaleOneHalfIsNormalized(self):
     """Checks that rescale(construct(k), 0.5)[-1] = k for constant image k.

     The check is repeated for pyramids built with 0 through 4 levels, each
     with a freshly drawn constant k.
     """
     tolerance = 1e-8
     for num_levels in range(5):
         k = np.random.uniform()
         constant_image = np.full((2, 32, 32), k)
         pyramid = wavelet.construct(constant_image, num_levels, 'LeGall5/3')
         last_entry = wavelet.rescale(pyramid, 0.5)[-1]
         expected = k * np.ones_like(last_entry)
         np.testing.assert_allclose(
             last_entry, expected, atol=tolerance, rtol=tolerance)
示例#4
0
 def testRescaleDoesNotAffectTheFirstLevel(self):
     """Checks that rescale(x, s)[0] = x[0] for a randomly drawn s."""
     image = np.random.uniform(size=(2, 32, 32))
     pyramid = wavelet.construct(image, 4, 'LeGall5/3')
     # A random positive scale base, drawn after the image just as before.
     scale_base = np.exp(np.random.normal())
     rescaled = wavelet.rescale(pyramid, scale_base)
     # Compare only the first entry of each pyramid.
     self._assert_pyramids_close(pyramid[:1], rescaled[:1], 1e-8)
示例#5
0
 def testRescaleOneIsANoOp(self):
     """Checks that rescaling a pyramid with a scale base of 1 changes nothing."""
     tolerance = 1e-8
     image = np.random.uniform(size=(2, 32, 32))
     original = wavelet.construct(image, 4, 'LeGall5/3')
     unchanged = wavelet.rescale(original, 1.)
     self._assert_pyramids_close(original, unchanged, tolerance)