def testImageDctIsOrthonormal(self):
    """Checks that image_dct() preserves inner products.

    An orthonormal transform T satisfies <a, b> == <T(a), T(b)>, so for several
    random image pairs the pixel-domain inner product is compared with the
    inner product of their image_dct() coefficients.
    """
    def random_image():
        # One random float32 test image of shape (4, 4, 2).
        return np.float32(np.random.uniform(size=(4, 4, 2)))

    for _trial in range(4):
        # Tuple elements are evaluated left to right: `first` is drawn first.
        first, second = random_image(), random_image()
        pixel_inner = tf.reduce_sum(first * second)
        dct_inner = tf.reduce_sum(util.image_dct(first) * util.image_dct(second))
        self.assertAllClose(pixel_inner, dct_inner)
def transform_to_mat(self, x):
    """Transforms a batch of images into a matrix with one row per image.

    When `self.color_space` is 'YUV' the input is first passed through
    `util.rgb_to_syuv`; otherwise (e.g. 'RGB') it is used as-is. Every channel
    of every image is then mapped into the representation named by
    `self.representation`: a wavelet decomposition when that name is one of
    `wavelet.generate_filters()`, `util.image_dct` when it is 'DCT', and no
    change otherwise (e.g. 'PIXEL'). Finally each image is flattened to a row.

    Args:
      x: A rank-4 array or tensor whose shape is unpacked as
        (num_batches, width, height, num_channels).

    Returns:
      A tensor of shape (num_batches, width * height * num_channels). Within a
      row the channel index varies fastest: the row is the image's
      (width, height, num_channels) block flattened in row-major order.
    """
    assert len(x.shape) == 4
    images = torch.as_tensor(x)
    if self.color_space == 'YUV':
        images = util.rgb_to_syuv(images)

    _, width, height, num_channels = images.shape

    # Fold channels into the batch axis so each channel becomes its own plane:
    #   (num_batches, width, height, num_channels)
    #   -> (num_batches, num_channels, width, height)
    #   -> (num_batches * num_channels, width, height)
    channels_first = images.permute(0, 3, 1, 2)
    planes = torch.reshape(channels_first, (-1, width, height))

    if self.representation in wavelet.generate_filters():
        pyramid = wavelet.construct(
            planes, self.wavelet_num_levels, self.representation)
        pyramid = wavelet.rescale(pyramid, self.wavelet_scale_base)
        planes = wavelet.flatten(pyramid)
    elif self.representation == 'DCT':
        planes = util.image_dct(planes)

    # Unfold the channels, move them back to the last axis, and flatten:
    #   (num_batches * num_channels, width, height)
    #   -> (num_batches, num_channels, width, height)
    #   -> (num_batches, width, height, num_channels)
    #   -> (num_batches, width * height * num_channels)
    per_image = torch.reshape(planes, (-1, num_channels, width, height))
    channels_last = per_image.permute(0, 2, 3, 1)
    return torch.reshape(channels_last, [-1, width * height * num_channels])
def testImageDctIsCorrect(self):
    """Tests that image_dct() matches a reference separable 2-D DCT.

    The reference applies an orthonormal type-II 1-D DCT along the last axis of
    the (3, 8, 8) input, swaps the last two axes, applies the same 1-D DCT
    again (now along what was axis 1), and swaps the axes back. This yields the
    orthonormal 2-D DCT-II over axes 1 and 2 of each 8x8 slice, which
    image_dct() must reproduce to within 1e-5.
    """
    im = np.float32(np.random.uniform(size=(3, 8, 8)))
    im_dct = util.image_dct(im)
    # NOTE(review): `tf.spectral` is the TensorFlow 1.x namespace; in TF 2.x
    # this op is only available as `tf.signal.dct`. Confirm the TF version
    # this suite targets.
    # First pass: 1-D DCT along axis 2, then transpose so axis 1 is last.
    dct_y = tf.transpose(tf.spectral.dct(im, type=2, norm='ortho'), [0, 2, 1])
    # Second pass: 1-D DCT along the original axis 1, then restore the layout.
    im_dct_true = tf.transpose(
        tf.spectral.dct(dct_y, type=2, norm='ortho'), [0, 2, 1])
    self.assertAllClose(im_dct, im_dct_true, atol=1e-5, rtol=1e-5)
def testImageDctRoundTrip(self):
    """Checks that image_idct() undoes image_dct(), i.e. idct(dct(x)) == x."""
    original = np.float32(np.random.uniform(size=(32, 32, 3)))
    coefficients = util.image_dct(original)
    reconstructed = util.image_idct(coefficients)
    self.assertAllClose(original, reconstructed)