def process(self, batch_x, quality=None, return_entropy=False):
    """
    Compress a batch of images (NHW3:rgb) with a given JPEG quality factor.

    Args:
        batch_x: batch of images; on the codec path (no model) anything that is
            not a numpy array is converted with `.numpy()`.
        quality: value convertible to `int` that passes `is_valid_quality`;
            when None, `self.quality` is used.
        return_entropy: when True, additionally return an entropy estimate.

    Returns:
        - codec path (`self._model is None`): the compressed batch, or
          `(batch, np.nan)` when `return_entropy` is set;
        - model path: the compressed batch `y`, or `(y, X, entropy)` when
          `return_entropy` is set (`X` is the second output of the model).
        NOTE: the tuple arity differs between the two paths.

    Raises:
        ValueError: if the quality is unspecified, cannot be converted to an
            integer, or is rejected by `is_valid_quality`.

    NOTE: random selection of the quality from a range / set of values
    (iterable `quality`) used to be supported but is currently disabled.
    """
    quality = self.quality if quality is None else quality

    # int(None) / int([..]) raise TypeError; report them as the same
    # ValueError that the validation below raises.
    try:
        quality = int(quality)
    except (TypeError, ValueError):
        raise ValueError('Invalid or unspecified JPEG quality!') from None

    if not is_valid_quality(quality):
        raise ValueError('Invalid or unspecified JPEG quality!')

    # Codec path: no differentiable model, delegate to the helper module.
    if self._model is None:
        if not isinstance(batch_x, np.ndarray):
            batch_x = batch_x.numpy()
        batch_y = jpeg_helpers.compress_batch(batch_x, quality)[0]
        return (batch_y, np.nan) if return_entropy else batch_y

    # Model path: temporarily swap in the tables of the requested quality.
    override_tables = quality != self.quality
    if override_tables:
        # Build both tables before touching the model so that a failure here
        # cannot leave it half-updated.
        q_luma, q_chroma = jpeg_qtable(quality, 0), jpeg_qtable(quality, 1)
        old_q_luma, old_q_chroma = self._model._q_mtx_luma, self._model._q_mtx_chroma
        self._model._q_mtx_luma = q_luma
        self._model._q_mtx_chroma = q_chroma

    try:
        y, X = self._model(batch_x)
    finally:
        # Restore the model's own tables even if the forward pass raised.
        if override_tables:
            self._model._q_mtx_luma, self._model._q_mtx_chroma = old_q_luma, old_q_chroma

    if return_entropy:
        # TODO This currently takes too much memory
        entropy = tf_helpers.entropy(X, self._model.quantization.codebook, v=5, gamma=5)[0]
        return y, X, entropy

    return y
def train_q_table(self, batch_x, alpha, beta, n_times):
    """
    Fine-tune the luma / chroma quantization tables of the model on a batch.

    The tables are re-created as trainable 8x8 weights initialised from
    `jpeg_qtable(self.quality, ...)` and updated for `n_times + 1` steps with
    SGD(5) on per-parameter normalised gradients. The loss is
    `alpha * distortion + beta * penalty`, where `distortion` is the mean
    squared error on the 0-255 scale and `penalty` depends on the entropy
    relative to the entropy measured at the first step (the target):
      - entropy below the target: `entropy - target` (linear);
      - otherwise: `(100 * (entropy - target)) ** 2`.

    On the last step the tables are rounded and clipped to [1, 256]; from then
    on the `_q_mtx_*` attributes hold the resulting tensors rather than the
    trainable weights.

    Args:
        batch_x: batch of images passed to `self.process`
            (presumably scaled to [0, 1], given the 255.0 factor - TODO confirm).
        alpha: weight of the distortion term.
        beta: weight of the entropy term.
        n_times: index of the last optimisation step (the loop is inclusive).

    Returns:
        Tuple `(q_mtx_luma, q_mtx_chroma)` with the trained tables.
    """
    # set the training point at self.quality of standard jpeg quantization table
    q_mtx_luma_init = jpeg_qtable(self.quality, 0)
    q_mtx_chroma_init = jpeg_qtable(self.quality, 1)

    # Keyword arguments: the positional order of `add_weight` differs between
    # Keras versions (name-first vs. shape-first).
    self._model._q_mtx_luma = self._model.add_weight(
        name='Q_mtx_luma', shape=[8, 8], dtype=tf.float32,
        initializer=tf.constant_initializer(q_mtx_luma_init))
    self._model._q_mtx_chroma = self._model.add_weight(
        name='Q_mtx_chroma', shape=[8, 8], dtype=tf.float32,
        initializer=tf.constant_initializer(q_mtx_chroma_init))
    print("quantization before training: ", self._model._q_mtx_luma, self._model._q_mtx_chroma)

    opt = tf.keras.optimizers.SGD(5)
    target_entropy = 0

    for epoch in range(0, n_times+1):
        with tf.GradientTape() as tape1:
            batch_y, _, entropy_code = self.process(batch_x, return_entropy=True)
            # SSIM is only reported in the log line; it is not part of the loss.
            ssim = tf.reduce_mean(tf.image.ssim(tf.convert_to_tensor(255.0*batch_y), tf.convert_to_tensor(255.0*batch_x), max_val=255))
            distortion = tf.reduce_mean(tf.math.pow((255.0*batch_x-255.0*batch_y), 2))

            # The entropy of the very first step becomes the rate target.
            if epoch == 0:
                target_entropy = entropy_code

            if entropy_code < target_entropy:
                loss = alpha*distortion + beta*(entropy_code-target_entropy)
            else:
                loss = alpha*distortion + beta*(100*(entropy_code-target_entropy))**2

        grad = tape1.gradient(loss, self.parameters)
        # Normalise each gradient so the step size is set by the optimizer only.
        grad = [x/(1e-12 + tf.linalg.norm(x)) for x in grad]
        opt.apply_gradients(zip(grad, self.parameters))

        if epoch == n_times:
            # Final step: snap the tables to integers in [1, 256].
            self._model._q_mtx_luma = tf.clip_by_value(tf.round(self._model._q_mtx_luma), clip_value_min=1, clip_value_max=256)
            self._model._q_mtx_chroma = tf.clip_by_value(tf.round(self._model._q_mtx_chroma), clip_value_min=1, clip_value_max=256)

        qf = self.estimate_qf(0)
        qt = self._model._q_mtx_luma
        print(f'{epoch:03d} --> {ssim:.2f} {loss:.2f} {distortion:.2f} {entropy_code:.2f} {qf} {qt[0, 0]} {qt[0, 1]}')

    print("quantization after training: ", self._model._q_mtx_luma, self._model._q_mtx_chroma)
    return self._model._q_mtx_luma, self._model._q_mtx_chroma
def __init__(self, quality=None, rounding_approximation='sin', rounding_approximation_steps=5, trainable=False):
    """
    Set up the quantization tables, color / DCT matrices and quantization layer.

    Args:
        quality: JPEG quality accepted by `is_valid_quality`, or None. When it
            is not a plain number (per `is_number`), the quantization tables
            are initialised to all ones.
        rounding_approximation: one of 'sin', 'harmonic', 'soft' or None;
            forwarded to the `Quantization` layer.
        rounding_approximation_steps: forwarded to the `Quantization` layer.
        trainable: when True, the quantization tables are registered as
            8x8 float32 weights; otherwise they are kept as plain arrays.

    Raises:
        ValueError: for an invalid quality or an unsupported rounding
            approximation.
    """
    # Zero-argument super() already binds `self`; passing it again would hand
    # the base constructor a stray positional argument.
    super().__init__()

    if quality is not None and not is_valid_quality(quality):
        raise ValueError('Invalid JPEG quality: requires int in [1,100] or an iterable with least 2 such numbers')

    # Sanitize inputs
    if rounding_approximation is not None and rounding_approximation not in ['sin', 'harmonic', 'soft']:
        raise ValueError('Unsupported rounding approximation: {}'.format(rounding_approximation))

    # Quantization tables: standard tables for a numeric quality, ones otherwise
    if is_number(quality):
        q_mtx_luma_init = jpeg_qtable(quality, 0)
        q_mtx_chroma_init = jpeg_qtable(quality, 1)
    else:
        q_mtx_luma_init = np.ones((8, 8), dtype=np.float32)
        q_mtx_chroma_init = np.ones((8, 8), dtype=np.float32)

    if trainable:
        # Keyword arguments: the positional order of `add_weight` differs
        # between Keras versions (name-first vs. shape-first).
        self._q_mtx_luma = self.add_weight(
            name='Q_mtx_luma', shape=[8, 8], dtype=tf.float32,
            initializer=tf.constant_initializer(q_mtx_luma_init))
        self._q_mtx_chroma = self.add_weight(
            name='Q_mtx_chroma', shape=[8, 8], dtype=tf.float32,
            initializer=tf.constant_initializer(q_mtx_chroma_init))
    else:
        self._q_mtx_luma = q_mtx_luma_init
        self._q_mtx_chroma = q_mtx_chroma_init

    # Parameters
    self.quality = quality
    self.trainable = trainable
    self.rounding_approximation = rounding_approximation
    self.rounding_approximation_steps = rounding_approximation_steps

    # RGB to YCbCr conversion
    # (first column looks like a constant offset term - TODO confirm against usage)
    self._color_F = np.array([
        [0, 0.299, 0.587, 0.114],
        [128, -0.168736, -0.331264, 0.5],
        [128, 0.5, -0.418688, -0.081312],
    ], dtype=np.float32)
    self._color_I = np.array([
        [-1.402 * 128, 1, 0, 1.402],
        [1.058272 * 128, 1, -0.344136, -0.714136],
        [-1.772 * 128, 1, 1.772, 0],
    ], dtype=np.float32)

    # DCT
    self._dct_F = np.array([
        [0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536],
        [0.4904, 0.4157, 0.2778, 0.0975, -0.0975, -0.2778, -0.4157, -0.4904],
        [0.4619, 0.1913, -0.1913, -0.4619, -0.4619, -0.1913, 0.1913, 0.4619],
        [0.4157, -0.0975, -0.4904, -0.2778, 0.2778, 0.4904, 0.0975, -0.4157],
        [0.3536, -0.3536, -0.3536, 0.3536, 0.3536, -0.3536, -0.3536, 0.3536],
        [0.2778, -0.4904, 0.0975, 0.4157, -0.4157, -0.0975, 0.4904, -0.2778],
        [0.1913, -0.4619, 0.4619, -0.1913, -0.1913, 0.4619, -0.4619, 0.1913],
        [0.0975, -0.2778, 0.4157, -0.4904, 0.4904, -0.4157, 0.2778, -0.0975],
    ], dtype=np.float32)
    # The inverse transform is taken as the transpose of the forward matrix.
    self._dct_I = np.transpose(self._dct_F)

    # Quantization layer
    self.quantization = Quantization(self.rounding_approximation, self.rounding_approximation_steps, latent_bpf=9)