Beispiel #1
0
 def testImageDctIsOrthonormal(self):
     """Check that util.image_dct() preserves inner products.

     For several random pairs of float32 images, the inner product taken in
     pixel space, <a, b>, must match the inner product of their DCT
     coefficients, <image_dct(a), image_dct(b)>.
     """
     image_shape = (4, 4, 2)
     num_trials = 4
     for _ in range(num_trials):
         image_a = np.float32(np.random.uniform(size=image_shape))
         image_b = np.float32(np.random.uniform(size=image_shape))
         coeffs_a = util.image_dct(image_a)
         coeffs_b = util.image_dct(image_b)
         pixel_inner = tf.reduce_sum(image_a * image_b)
         coeff_inner = tf.reduce_sum(coeffs_a * coeffs_b)
         self.assertAllClose(pixel_inner, coeff_inner)
Beispiel #2
0
    def transform_to_mat(self, x):
        """Flatten a batch of images into a 2-D matrix with one row per image.

        When `self.color_space == 'YUV'` the pixels are first converted with
        `util.rgb_to_syuv`. Each channel is then mapped independently into the
        spatial representation named by `self.representation`:
        - a wavelet basis, if that name is returned by
          `wavelet.generate_filters()`;
        - the DCT basis, via `util.image_dct`, if it is 'DCT';
        - unchanged otherwise.

        Args:
          x: a rank-4 array-like laid out as
            (num_batches, width, height, num_channels). It is converted with
            `torch.as_tensor`.

        Returns:
          A tensor of shape (num_batches, width * height * num_channels).
          Within each row the values are ordered (width, height, channel),
          with the channel index varying fastest. This assumes the chosen
          spatial transform keeps each channel at width * height values;
          TODO confirm for the wavelet path.
        """
        assert len(x.shape) == 4
        images = torch.as_tensor(x)
        if self.color_space == 'YUV':
            images = util.rgb_to_syuv(images)
        # Any other color space (e.g. 'RGB') leaves the pixels untouched.

        _, width, height, num_channels = images.shape

        # (batch, width, height, channel) -> (batch * channel, width, height):
        # move the channel axis next to the batch axis so that every channel
        # becomes its own single-plane "image" for the spatial transform.
        channels_first = images.permute(0, 3, 1, 2)
        planes = torch.reshape(channels_first, (-1, width, height))

        # Map every plane into the requested spatial representation.
        if self.representation in wavelet.generate_filters():
            pyramid = wavelet.construct(planes, self.wavelet_num_levels,
                                        self.representation)
            planes = wavelet.flatten(
                wavelet.rescale(pyramid, self.wavelet_scale_base))
        elif self.representation == 'DCT':
            planes = util.image_dct(planes)
        # Any other representation (e.g. 'PIXEL') leaves the planes unchanged.

        # (batch * channel, width, height) -> (batch, width * height * channel)
        per_image = torch.reshape(planes, (-1, num_channels, width, height))
        channels_last = per_image.permute(0, 2, 3, 1)
        return torch.reshape(channels_last,
                             [-1, width * height * num_channels])
Beispiel #3
0
    def testImageDctIsCorrect(self):
        """Tests that image_dct() matches a reference separable 2-D DCT.

        The reference applies a 1-D type-II DCT with orthonormal scaling
        ('ortho') along the last axis, then swaps the two trailing axes. It
        applies the same 1-D DCT again, now along what was the middle axis,
        and swaps the axes back. The result is a 2-D DCT over the two
        trailing axes of a (3, 8, 8) batch of random float32 images, and
        util.image_dct() must agree with it to within 1e-5.
        """
        im = np.float32(np.random.uniform(size=(3, 8, 8)))

        im_dct = util.image_dct(im)

        # `tf.spectral` was renamed to `tf.signal` and is not part of the
        # TF 2.x API, so prefer the new name. Fall back to `tf.spectral`
        # only on TensorFlow releases that predate `tf.signal`.
        tf_signal = getattr(tf, 'signal', None) or tf.spectral

        # 1-D DCT along the last axis, then swap the two trailing axes ...
        dct_y = tf.transpose(tf_signal.dct(im, type=2, norm='ortho'),
                             [0, 2, 1])
        # ... 1-D DCT along the original middle axis, then swap back.
        dct_x = tf.transpose(tf_signal.dct(dct_y, type=2, norm='ortho'),
                             [0, 2, 1])
        im_dct_true = dct_x

        self.assertAllClose(im_dct, im_dct_true, atol=1e-5, rtol=1e-5)
Beispiel #4
0
 def testImageDctRoundTrip(self):
     """Check that util.image_idct() inverts util.image_dct().

     A random (32, 32, 3) float32 image is sent through the forward transform
     and back. It must be reproduced to within assertAllClose's default
     tolerance.
     """
     original = np.float32(np.random.uniform(size=(32, 32, 3)))
     coefficients = util.image_dct(original)
     reconstructed = util.image_idct(coefficients)
     self.assertAllClose(original, reconstructed)