def testWikiExpected(self):
    """Compare the forward 8x8 DCT of the wiki grayscale image to a reference.

    The reference is `self.test_wiki_expected`, passed through
    `create_single_batch`; only the first element of the pair that helper
    returns is used. The DCT output is reshaped to the reference's shape
    before an element-wise comparison with an absolute tolerance of 1e-2.
    """
    source_image = self.test_wiki_gray
    reference, _unused = create_single_batch(self.test_wiki_expected)
    with self.cached_session(use_gpu=True):
        forward_dct = dct.DCT2D(n=8)
        coefficients = forward_dct(center_img(source_image))
        # Align layouts so the comparison is element-for-element.
        coefficients = tf.reshape(coefficients, reference.shape)
        self.assertAllClose(coefficients, reference, atol=1e-2)
def testEncodingDecodingGrayBatch(self):
    """Round-trip a grayscale batch through DCT2D and InverseDCT2D.

    The batch is centred with `center_img` and transformed with an 8x8
    forward DCT. Its trailing axis is then regrouped per input channel,
    inverted with an 8x8 inverse DCT, and de-centred with `de_center_img`.
    The reconstruction must match the original batch to within 1e-3 in
    absolute value, element-wise.
    """
    test_img = self.test_wiki_gray_batch
    with self.cached_session(use_gpu=True):
        dct_transform = dct.DCT2D(n=8)
        x = center_img(test_img)
        channels = x.shape[-1]
        x = dct_transform(x)
        # Split the trailing axis into `channels` equal slices and stack them
        # on a new trailing axis. Then split that new axis back apart,
        # concatenate the pieces along the second-to-last axis, and drop the
        # resulting size-1 trailing axis.
        # NOTE(review): presumably this mirrors the per-channel coefficient
        # layout used elsewhere -- confirm against the layer code.
        x = tf.stack(tf.split(x, channels, axis=-1), -1)
        y = tf.concat(tf.split(x, channels, axis=-1), -2)
        y = tf.squeeze(y, -1)
        dct_inverse = dct.InverseDCT2D(n=8)
        y = dct_inverse(y)
        y = de_center_img(y)
        # Bound the *absolute* reconstruction error. A signed difference
        # would only catch y overshooting test_img and would silently
        # accept any undershoot.
        self.assertAllLessEqual(tf.abs(tf.subtract(y, test_img)), 1e-3)
def __init__(self, n, quality=50, **kwargs):
    """Create a non-trainable layer holding an n-point DCT pair and masks.

    Args:
      n: Size passed to both `dct.DCT2D` and `dct.InverseDCT2D`, and to
        `self._create_masks`. Stored as `self.n`.
      quality: Passed to `self._create_masks` together with `n` (default
        50). Stored as `self.quality`.
      **kwargs: Forwarded to the base-class constructor. Any `trainable`
        entry is discarded because this layer always passes
        `trainable=False`. Previously, supplying it raised a TypeError for a
        duplicate keyword argument. This happens, for example, when a
        base-layer config dict that contains `trainable` is passed back
        into the constructor.
    """
    # The layer is always constructed non-trainable. Ignore a caller-supplied
    # value instead of failing with "multiple values for keyword argument".
    kwargs.pop("trainable", None)
    super(JPEG_Mask, self).__init__(trainable=False, **kwargs)
    self.dct_transform = dct.DCT2D(n=n)
    self.dct_inverse = dct.InverseDCT2D(n=n)
    self.masks = self._create_masks(n, quality)
    self.n = n
    # Keep the remaining construction argument so the layer's settings can be
    # reported back. NOTE(review): presumably useful for get_config --
    # confirm against the rest of the class.
    self.quality = quality