예제 #1
    def process(self, batch_x, quality=None, return_entropy=False):
        """ Compress a batch of images (NHW3:rgb) with a given quality factor:

        - if quality is None - use self.quality
        - if quality is a number - use this quality level (truncated with int())
        - if quality is an iterable with 2 numbers - use a random integer from that range
          (drawn with np.random.randint, i.e. upper bound exclusive)
        - if quality is an iterable with >2 numbers - use a random value from that set

        Args:
            batch_x: batch of images. When no model is attached and it is not an
                np.ndarray, it is converted with ``.numpy()``.
            quality: quality specification as described above.
            return_entropy: also return an entropy estimate.

        Returns:
            Without a model: the first element returned by ``jpeg_helpers.compress_batch``,
            or ``(that, np.nan)`` when ``return_entropy`` is set.
            With a model: ``y``, or ``(y, X, entropy)`` when ``return_entropy`` is set.

        Raises:
            ValueError: if the quality is invalid or unspecified.
        """
        quality = self.quality if quality is None else quality
        if quality is None:
            raise ValueError('Invalid or unspecified JPEG quality!')

        if hasattr(quality, '__getitem__') and not isinstance(quality, str):
            # Iterable spec: validate the whole spec, then draw a single level from it
            if not is_valid_quality(quality):
                raise ValueError('Invalid or unspecified JPEG quality!')
            if len(quality) > 2:
                quality = int(np.random.choice(quality))
            elif len(quality) == 2:
                quality = int(np.random.randint(quality[0], quality[1]))
            else:
                raise ValueError('Invalid or unspecified JPEG quality!')
        else:
            # Scalar spec: truncate to int first, then validate the resulting level
            quality = int(quality)

        if not is_valid_quality(quality):
            raise ValueError('Invalid or unspecified JPEG quality!')

        if self._model is None:
            # No model attached: fall back to the jpeg_helpers codec on numpy data
            if not isinstance(batch_x, np.ndarray):
                batch_x = batch_x.numpy()
            compressed = jpeg_helpers.compress_batch(batch_x, quality)[0]
            # NOTE(review): with return_entropy this branch returns a 2-tuple while the
            # model branch returns a 3-tuple (y, X, entropy); callers unpacking three
            # values only work when a model is attached.
            if return_entropy:
                return compressed, np.nan
            return compressed

        # Temporarily swap in the tables for the requested quality; restore them even
        # if the forward pass raises so the model is never left modified.
        override_tables = quality != self.quality
        if override_tables:
            old_q_luma, old_q_chroma = self._model._q_mtx_luma, self._model._q_mtx_chroma
            self._model._q_mtx_luma = jpeg_qtable(quality, 0)
            self._model._q_mtx_chroma = jpeg_qtable(quality, 1)
        try:
            y, X = self._model(batch_x)
        finally:
            if override_tables:
                self._model._q_mtx_luma, self._model._q_mtx_chroma = old_q_luma, old_q_chroma

        if return_entropy:
            # TODO This currently takes too much memory
            entropy = tf_helpers.entropy(X, self._model.quantization.codebook, v=5, gamma=5)[0]
            return y, X, entropy

        return y
예제 #2
    def train_q_table(self, batch_x, alpha, beta, n_times, learning_rate=5):
        """Fine-tune the model's quantization tables on a batch by gradient descent.

        Both tables are re-created as trainable 8x8 float32 weights on ``self._model``.
        They are initialised from ``jpeg_qtable(self.quality, ...)``.

        The loop runs ``n_times + 1`` iterations. Each iteration minimises
        ``alpha * distortion + beta * entropy_term``, where the entropy target is the
        entropy measured in the first iteration:

        - entropy below the target: linear term ``entropy - target``;
        - otherwise: quadratic term ``(100 * (entropy - target)) ** 2``.

        Each gradient is normalised to unit L2 norm before being applied with SGD.
        After the last iteration the tables are rounded and clipped to [1, 256].

        Args:
            batch_x: batch of images. The 255.0 scaling suggests values in [0, 1]
                (TODO confirm).
            alpha: weight of the distortion term (MSE in 255-scaled units).
            beta: weight of the entropy term.
            n_times: index of the last iteration (the loop runs n_times + 1 times).
            learning_rate: SGD learning rate (default 5, the previous hard-coded value).

        Returns:
            Tuple ``(q_mtx_luma, q_mtx_chroma)`` as stored on ``self._model`` at the end.
        """
        # Start training from the standard jpeg quantization table at self.quality
        q_mtx_luma_init = jpeg_qtable(self.quality, 0)
        q_mtx_chroma_init = jpeg_qtable(self.quality, 1)
        self._model._q_mtx_luma = self._model.add_weight('Q_mtx_luma', [8, 8], dtype=tf.float32, initializer=tf.constant_initializer(q_mtx_luma_init))
        self._model._q_mtx_chroma = self._model.add_weight('Q_mtx_chroma', [8, 8], dtype=tf.float32, initializer=tf.constant_initializer(q_mtx_chroma_init))
        print("quantization before training: ", self._model._q_mtx_luma, self._model._q_mtx_chroma)
        opt = tf.keras.optimizers.SGD(learning_rate)
        target_entropy = 0
        for epoch in range(0, n_times + 1):
            with tf.GradientTape() as tape:
                batch_y, _, entropy_code = self.process(batch_x, return_entropy=True)
                ssim = tf.reduce_mean(tf.image.ssim(tf.convert_to_tensor(255.0 * batch_y), tf.convert_to_tensor(255.0 * batch_x), max_val=255))
                distortion = tf.reduce_mean(tf.math.pow((255.0 * batch_x - 255.0 * batch_y), 2))
                if epoch == 0:
                    # Entropy of the initial tables becomes the target for later steps
                    target_entropy = entropy_code
                if entropy_code < target_entropy:
                    # Below target: linear term, so lower entropy still reduces the loss
                    loss = alpha * distortion + beta * (entropy_code - target_entropy)
                else:
                    # At/above target: steep quadratic penalty on the excess entropy
                    loss = alpha * distortion + beta * (100 * (entropy_code - target_entropy)) ** 2

            # NOTE(review): assumes self.parameters includes the Q_mtx weights created
            # above -- TODO confirm
            params = self.parameters
            grads = tape.gradient(loss, params)
            # Normalise each gradient to unit L2 norm; skip variables with no gradient
            grads_and_vars = [(g / (1e-12 + tf.linalg.norm(g)), v) for g, v in zip(grads, params) if g is not None]
            opt.apply_gradients(grads_and_vars)

            if epoch == n_times:
                # Last iteration: round the tables to integers and clip them to [1, 256]
                self._model._q_mtx_luma = tf.clip_by_value(tf.round(self._model._q_mtx_luma), clip_value_min=1, clip_value_max=256)
                self._model._q_mtx_chroma = tf.clip_by_value(tf.round(self._model._q_mtx_chroma), clip_value_min=1, clip_value_max=256)

            q_luma = self._model._q_mtx_luma
            print(f'{epoch:03d} --> {float(ssim):.2f} {float(loss):.2f} {float(distortion):.2f} '
                  f'{float(entropy_code):.2f} {self.quality} {float(q_luma[0, 0])} {float(q_luma[0, 1])}')
        print("quantization after training: ", self._model._q_mtx_luma, self._model._q_mtx_chroma)
        return self._model._q_mtx_luma, self._model._q_mtx_chroma
예제 #3
    def __init__(self, quality=None, rounding_approximation='sin', rounding_approximation_steps=5, trainable=False):
        """Set up the JPEG model: quantization tables, color/DCT matrices and quantization layer.

        Args:
            quality: None, or anything accepted by ``is_valid_quality``. The error message
                describes this as an int in [1, 100] or an iterable with at least 2 such
                numbers. The tables start from ``jpeg_qtable(quality, ...)`` only when
                ``is_number(quality)`` holds; otherwise they start as 8x8 all-ones.
            rounding_approximation: 'sin', 'harmonic', 'soft' or None; forwarded to
                ``Quantization``.
            rounding_approximation_steps: forwarded to ``Quantization``.
            trainable: if True, the tables are created as 8x8 float32 weights via
                ``add_weight``; otherwise they are kept as plain arrays.

        Raises:
            ValueError: for an invalid quality or an unsupported rounding approximation.
        """
        # Required before add_weight / attribute assignment on a Keras layer or model
        super().__init__()

        if quality is not None and not is_valid_quality(quality):
            raise ValueError('Invalid JPEG quality: requires int in [1,100] or an iterable with least 2 such numbers')

        # Sanitize inputs
        if rounding_approximation is not None and rounding_approximation not in ['sin', 'harmonic', 'soft']:
            raise ValueError('Unsupported rounding approximation: {}'.format(rounding_approximation))

        # Quantization tables: standard JPEG tables for a numeric quality, all-ones otherwise
        if is_number(quality):
            q_mtx_luma_init = jpeg_qtable(quality, 0)
            q_mtx_chroma_init = jpeg_qtable(quality, 1)
        else:
            q_mtx_luma_init = np.ones((8, 8), dtype=np.float32)
            q_mtx_chroma_init = np.ones((8, 8), dtype=np.float32)

        if trainable:
            self._q_mtx_luma = self.add_weight('Q_mtx_luma', [8, 8], dtype=tf.float32, initializer=tf.constant_initializer(q_mtx_luma_init))
            self._q_mtx_chroma = self.add_weight('Q_mtx_chroma', [8, 8], dtype=tf.float32, initializer=tf.constant_initializer(q_mtx_chroma_init))
        else:
            self._q_mtx_luma = q_mtx_luma_init
            self._q_mtx_chroma = q_mtx_chroma_init

        # Parameters
        self.quality = quality
        self.trainable = trainable
        self.rounding_approximation = rounding_approximation
        self.rounding_approximation_steps = rounding_approximation_steps

        # RGB to YCbCr conversion. Rows are Y, Cb, Cr; the first column holds the offset
        # and the rest appear to be the R, G, B coefficients (0.299/0.587/0.114 luma).
        self._color_F = np.array([[0, 0.299, 0.587, 0.114], [128, -0.168736, -0.331264, 0.5], [128, 0.5, -0.418688, -0.081312]], dtype=np.float32)
        # Inverse (YCbCr to RGB). Rows are R, G, B; the first column is the offset
        # (coefficient * 128) and the rest appear to be the Y, Cb, Cr coefficients.
        self._color_I = np.array([[-1.402 * 128, 1, 0, 1.402], [1.058272 * 128, 1, -0.344136, -0.714136], [-1.772 * 128, 1, 1.772, 0]], dtype=np.float32)
        # DCT: rows are the 8-point orthonormal DCT-II basis vectors, rounded to 4 decimals
        self._dct_F = np.array([[0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536],
                                [0.4904, 0.4157, 0.2778, 0.0975, -0.0975, -0.2778, -0.4157, -0.4904],
                                [0.4619, 0.1913, -0.1913, -0.4619, -0.4619, -0.1913, 0.1913, 0.4619],
                                [0.4157, -0.0975, -0.4904, -0.2778, 0.2778, 0.4904, 0.0975, -0.4157],
                                [0.3536, -0.3536, -0.3536, 0.3536, 0.3536, -0.3536, -0.3536, 0.3536],
                                [0.2778, -0.4904, 0.0975, 0.4157, -0.4157, -0.0975, 0.4904, -0.2778],
                                [0.1913, -0.4619, 0.4619, -0.1913, -0.1913, 0.4619, -0.4619, 0.1913],
                                [0.0975, -0.2778, 0.4157, -0.4904, 0.4904, -0.4157, 0.2778, -0.0975]], dtype=np.float32)
        # Inverse DCT is the transpose of the (orthonormal) forward matrix
        self._dct_I = np.transpose(self._dct_F)
        # Quantization layer
        self.quantization = Quantization(self.rounding_approximation, self.rounding_approximation_steps, latent_bpf=9)