def is_valid_quality(quality):
    """Check whether ``quality`` is an acceptable JPEG quality specification.

    Accepted forms:
      - a single number (per ``is_number``) in [1, 100];
      - an indexable sequence with at least 2 items, every item in [1, 100].

    Strings and bytes are rejected outright. Inputs whose items cannot be
    compared with integers (or indexable objects without a length, such as
    numpy scalars) return False instead of leaking a ``TypeError``.

    Returns:
        bool: True if the specification is valid, False otherwise.
    """
    # str/bytes are indexable, but are never a valid set of quality factors
    # (bytes would otherwise iterate as small ints and pass the range check).
    if isinstance(quality, (str, bytes)):
        return False
    try:
        if is_number(quality) and 1 <= quality <= 100:
            return True
        if hasattr(quality, '__getitem__') and len(quality) > 1:
            return all(1 <= x <= 100 for x in quality)
    except TypeError:
        # Non-comparable items or unsized objects: not a valid specification.
        return False
    return False
def to_json(self):
    """Return the parameters of ``self.to_dict()`` in a JSON-friendly form.

    Values recognised by ``utils.is_number`` are kept as they are; every other
    value is replaced by its ``str()`` representation. A new dict is returned.
    """
    serializable = {}
    for name, val in self.to_dict().items():
        if utils.is_number(val):
            serializable[name] = val
        else:
            serializable[name] = str(val)
    return serializable
def log_metric(self, metric, scope, value, raw=False):
    """Append a value to ``self.performance[metric][scope]``.

    Unless ``raw`` is True, the value is reduced to a Python float first:
    values recognised by ``utils.is_number`` are cast with ``float()``;
    anything else is averaged with ``np.mean`` and then cast.
    When ``raw`` is True the value is stored unchanged.
    """
    entry = value
    if not raw:
        entry = float(value) if utils.is_number(value) else float(np.mean(value))
    self.performance[metric][scope].append(entry)
def __init__(self, quality=None, rounding_approximation='sin', rounding_approximation_steps=5, trainable=False):
    """Set up the JPEG codec state: quantization tables, colour and DCT transforms.

    Args:
        quality: JPEG quality spec - a number in [1, 100], an indexable of at
            least 2 such numbers, or None. Only a single number seeds the
            quantization tables (via ``jpeg_qtable``); otherwise they start as
            8x8 all-ones float32 matrices.
        rounding_approximation: None or one of 'sin', 'harmonic', 'soft';
            forwarded to the ``Quantization`` layer.
        rounding_approximation_steps: forwarded to the ``Quantization`` layer.
        trainable: if True, the quantization tables are created with
            ``self.add_weight``; otherwise they are plain numpy arrays.

    Raises:
        ValueError: on an invalid ``quality`` or an unsupported
            ``rounding_approximation``.
    """
    # super() already binds self; passing it again would land in the parent's
    # first positional parameter (e.g. ``trainable`` for a Keras Layer).
    super().__init__()

    if quality is not None and not is_valid_quality(quality):
        raise ValueError('Invalid JPEG quality: requires int in [1,100] or an iterable with least 2 such numbers')

    # Sanitize inputs
    if rounding_approximation is not None and rounding_approximation not in ['sin', 'harmonic', 'soft']:
        raise ValueError('Unsupported rounding approximation: {}'.format(rounding_approximation))

    # Quantization tables: standard tables for a fixed quality, all-ones otherwise
    if is_number(quality):
        q_mtx_luma_init = jpeg_qtable(quality, 0)
        q_mtx_chroma_init = jpeg_qtable(quality, 1)
    else:
        q_mtx_luma_init = np.ones((8, 8), dtype=np.float32)
        q_mtx_chroma_init = np.ones((8, 8), dtype=np.float32)

    if trainable:
        self._q_mtx_luma = self.add_weight('Q_mtx_luma', [8, 8], dtype=tf.float32,
                                           initializer=tf.constant_initializer(q_mtx_luma_init))
        self._q_mtx_chroma = self.add_weight('Q_mtx_chroma', [8, 8], dtype=tf.float32,
                                             initializer=tf.constant_initializer(q_mtx_chroma_init))
    else:
        self._q_mtx_luma = q_mtx_luma_init
        self._q_mtx_chroma = q_mtx_chroma_init

    # Parameters
    self.quality = quality
    self.trainable = trainable
    self.rounding_approximation = rounding_approximation
    self.rounding_approximation_steps = rounding_approximation_steps

    # RGB to YCbCr conversion: each row holds [offset, R, G, B] coefficients for
    # Y, Cb and Cr respectively (presumably applied to [1, R, G, B] - confirm in
    # the forward pass); _color_I holds the corresponding inverse coefficients.
    self._color_F = np.array([[0, 0.299, 0.587, 0.114],
                              [128, -0.168736, -0.331264, 0.5],
                              [128, 0.5, -0.418688, -0.081312]], dtype=np.float32)
    self._color_I = np.array([[-1.402 * 128, 1, 0, 1.402],
                              [1.058272 * 128, 1, -0.344136, -0.714136],
                              [-1.772 * 128, 1, 1.772, 0]], dtype=np.float32)

    # DCT: 8-point orthonormal DCT-II basis (entries rounded to 4 decimals); the
    # inverse is its transpose, so the round trip is only approximately exact.
    self._dct_F = np.array([[0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536, 0.3536],
                            [0.4904, 0.4157, 0.2778, 0.0975, -0.0975, -0.2778, -0.4157, -0.4904],
                            [0.4619, 0.1913, -0.1913, -0.4619, -0.4619, -0.1913, 0.1913, 0.4619],
                            [0.4157, -0.0975, -0.4904, -0.2778, 0.2778, 0.4904, 0.0975, -0.4157],
                            [0.3536, -0.3536, -0.3536, 0.3536, 0.3536, -0.3536, -0.3536, 0.3536],
                            [0.2778, -0.4904, 0.0975, 0.4157, -0.4157, -0.0975, 0.4904, -0.2778],
                            [0.1913, -0.4619, 0.4619, -0.1913, -0.1913, 0.4619, -0.4619, 0.1913],
                            [0.0975, -0.2778, 0.4157, -0.4904, 0.4904, -0.4157, 0.2778, -0.0975]],
                           dtype=np.float32)
    self._dct_I = np.transpose(self._dct_F)

    # Quantization layer
    self.quantization = Quantization(self.rounding_approximation, self.rounding_approximation_steps, latent_bpf=9)
def process(self, batch_x, quality=None, return_entropy=False):
    """ Compress a batch of images (NHW3:rgb) with a given quality factor:

    - if quality is a number - use this quality level
    - if quality is an iterable with 2 numbers - use a random integer from that
      range (drawn with ``np.random.randint``, so the upper bound is exclusive)
    - if quality is an iterable with >2 numbers - use a random value from that set

    Args:
        batch_x: batch of images. When no model is attached and it is not a
            numpy array, it is converted with ``batch_x.numpy()``.
        quality: quality spec for this call; None means ``self.quality``.
        return_entropy: if True, also return an entropy estimate (currently
            always ``np.nan``).

    Returns:
        The compressed batch, or a ``(batch, entropy)`` tuple when
        ``return_entropy`` is True.

    Raises:
        ValueError: if the effective quality specification is invalid.
    """
    quality = self.quality if quality is None else quality

    if not is_valid_quality(quality):
        raise ValueError('Invalid or unspecified JPEG quality!')

    # Resolve the specification to a single integer quality factor
    if hasattr(quality, '__getitem__') and len(quality) > 2:
        quality = int(np.random.choice(quality))
    elif hasattr(quality, '__getitem__') and len(quality) == 2:
        quality = int(np.random.randint(quality[0], quality[1]))
    elif is_number(quality) and 1 <= quality <= 100:
        quality = int(quality)
    else:
        raise ValueError('Invalid quality! {}'.format(quality))

    if self._model is None:
        # No model attached: delegate to jpeg_helpers.compress_batch on numpy data
        if not isinstance(batch_x, np.ndarray):
            batch_x = batch_x.numpy()
        y = jpeg_helpers.compress_batch(batch_x, quality)[0]
        return (y, np.nan) if return_entropy else y

    model = self._model
    if quality != self.quality:
        # Temporarily swap in the tables for the requested quality; the finally
        # clause restores the originals even if table lookup or the forward pass
        # raises, so the model is never left with foreign tables.
        old_q_luma, old_q_chroma = model._q_mtx_luma, model._q_mtx_chroma
        try:
            model._q_mtx_luma = jpeg_qtable(quality, 0)
            model._q_mtx_chroma = jpeg_qtable(quality, 1)
            y, X = model(batch_x)
        finally:
            model._q_mtx_luma, model._q_mtx_chroma = old_q_luma, old_q_chroma
    else:
        y, X = model(batch_x)

    if return_entropy:
        # TODO This currently takes too much memory
        # entropy = tf_helpers.entropy(X, self._model.quantization.codebook)[0]
        entropy = np.nan
        return y, entropy

    return y
def _quality_mode(self, quality=None):
    """ Human-readable assessment of the current JPEG quality settings.

    Args:
        quality: quality spec to describe; None means ``self.quality``.

    Returns:
        str: 'trainable QF~<luma>/<chroma>' (estimated from the model's tables)
        when a trainable model is attached; otherwise 'QF=<q>' for a number,
        'QF~[a,b]' for a 2-item range, 'QF~{a,b,...}' for a larger set, and
        'QF=?' for anything else.
    """
    # Use an explicit None check rather than ``or``: array-valued specs would
    # otherwise raise an "ambiguous truth value" error.
    quality = self.quality if quality is None else quality

    if self._model is not None and self._model.trainable:
        return 'trainable QF~{}/{}'.format(
            jpeg_qf_estimation(self._model._q_mtx_luma, 0),
            jpeg_qf_estimation(self._model._q_mtx_chroma, 1))
    elif is_number(quality):
        return 'QF={}'.format(quality)
    elif hasattr(quality, '__getitem__') and len(quality) == 2:
        return 'QF~[{},{}]'.format(*quality)
    elif hasattr(quality, '__getitem__') and len(quality) > 2:
        return 'QF~{{{}}}'.format(','.join(str(x) for x in quality))
    else:
        return 'QF=?'
def update(self, **params):
    """Validate and store new parameter values in ``self._values``.

    Every keyword must be a key of ``self._specs``, whose entries are triples
    ``(_, dtype, validation)`` (the first item is not used here). For each
    parameter:

      - ``None`` values are skipped (the stored value is left untouched);
      - numeric NaN values are rejected;
      - the value is converted with ``dtype`` unless ``dtype`` is None;
      - ``validation`` (if not None) is applied to the converted value:
          * exact ``tuple`` of length 2 -> (min, max) bounds, either may be None;
          * exact ``set``               -> allowed values;
          * ``str`` (when dtype == str) -> substring containment check
            (note: despite the error message, this is NOT a regex match);
          * callable                    -> must return a truthy result.

    Parameters are processed in order and stored one by one, so values that
    passed before a failing one remain updated.

    Raises:
        ValueError: for unknown parameter names, NaN values, or failed checks.
    """
    for name, new_value in params.items():
        if name not in self._specs:
            raise ValueError('Unexpected parameter: {}!'.format(name))

        _, dtype, rule = self._specs[name]

        if new_value is None:
            continue

        if utils.is_number(new_value) and np.isnan(new_value):
            raise ValueError(
                'Invalid value {} for attribute {}'.format(new_value, name))

        candidate = new_value if dtype is None else dtype(new_value)

        if rule is None:
            pass
        elif type(rule) is tuple and len(rule) == 2:
            low, high = rule
            if low is not None and candidate < low:
                raise ValueError(
                    '{}: {} fails minimum validation check >= {}!'.format(name, candidate, low))
            if high is not None and candidate > high:
                raise ValueError(
                    '{}: {} fails maximum validation check (<= {})!'.format(name, candidate, high))
        elif type(rule) is set:
            if candidate not in rule:
                raise ValueError(
                    '{}: {} is not an allowed value ({})!'.format(name, candidate, rule))
        elif type(rule) is str and dtype == str:
            if rule not in candidate:
                raise ValueError(
                    '{}: {} does not match regex ({})!'.format(name, candidate, rule))
        elif callable(rule):
            if not rule(candidate):
                raise ValueError(
                    '{}: {} failed custom validation check!'.format(name, candidate))

        self._values[name] = candidate