def process(self, batch_x, quality=None, return_entropy=False):
    """
    Compress a batch of images (NHW3:rgb) with a given JPEG quality factor.

    :param batch_x: batch of images (numpy array, or a tensor with a .numpy() method)
    :param quality: integer JPEG quality factor; falls back to self.quality when None
        (note: unlike an earlier revision, iterables/ranges are NOT accepted here)
    :param return_entropy: if True, additionally return an entropy estimate
    :return: codec path -> compressed batch (plus np.nan when return_entropy);
             model path -> (y, X) latent pair, or (y, X, entropy) when return_entropy,
             otherwise just y
    :raises ValueError: when the resolved quality is not a valid JPEG quality
    """
    # Fall back to the instance-level default quality
    quality = self.quality if quality is None else quality
    quality = int(quality)
    if not is_valid_quality(quality):
        raise ValueError('Invalid or unspecified JPEG quality!')

    if self._model is None:
        # No differentiable model - use the libjpeg-based helper
        if not isinstance(batch_x, np.ndarray):
            batch_x = batch_x.numpy()
        compressed = jpeg_helpers.compress_batch(batch_x, quality)[0]
        if return_entropy:
            # Entropy is not estimated on the codec path
            return compressed, np.nan
        return compressed
    else:
        # Temporarily swap quantization tables when a non-default quality is requested
        if quality != self.quality:
            old_q_luma, old_q_chroma = self._model._q_mtx_luma, self._model._q_mtx_chroma
            self._model._q_mtx_luma = jpeg_qtable(quality, 0)
            self._model._q_mtx_chroma = jpeg_qtable(quality, 1)

        try:
            y, X = self._model(batch_x)
        finally:
            # Restore the original tables even if the model call raises
            if quality != self.quality:
                self._model._q_mtx_luma, self._model._q_mtx_chroma = old_q_luma, old_q_chroma

        if return_entropy:
            # TODO This currently takes too much memory
            entropy = tf_helpers.entropy(X, self._model.quantization.codebook, v=5, gamma=5)[0]
            return y, X, entropy

        return y
def get_jpeg_df(directory, write_files=False, effective_bytes=True, force_calc=False):
    """
    Compute and return (as Pandas DF) the rate-distortion curve for JPEG.

    The result is cached as 'jpeg.csv' in the source directory. If that file
    exists (and force_calc is False), the DF is loaded from it and returned
    instead of being recomputed. Decompressed images can optionally be dumped
    as PNG files into per-image sub-directories.

    :param directory: directory with the source images
    :param write_files: if True, save each decompressed JPEG image as a PNG
    :param effective_bytes: if True, count only image data (no JPEG headers)
    :param force_calc: if True, recompute even when a cached CSV exists
    :return: DataFrame with columns image_id, filename, codec, quality,
             ssim, psnr, msssim, msssim_db, bytes, bpp
    """
    files, _ = loading.discover_images(directory, n_images=-1, v_images=0)
    batch_x = loading.load_images(files, directory, load='y')
    batch_x = batch_x['y'].astype(np.float32) / (2**8 - 1)

    # Quality factors swept for the trade-off curve: 95, 90, ..., 10
    quality_levels = np.arange(95, 5, -5)
    df_jpeg_path = os.path.join(directory, 'jpeg.csv')

    if os.path.isfile(df_jpeg_path) and not force_calc:
        logger.info('Restoring JPEG stats from {}'.format(df_jpeg_path))
        df = pd.read_csv(df_jpeg_path, index_col=False)
    else:
        # Collect rows in a plain list and build the DataFrame once at the end:
        # DataFrame.append inside a loop is quadratic and removed in pandas >= 2.0.
        rows = []

        with tqdm.tqdm(total=len(files) * len(quality_levels), ncols=120, desc='JPEG') as pbar:
            for image_id, filename in enumerate(files):
                # Read the original image
                image = batch_x[image_id]

                for q in quality_levels:
                    # Compress images and get effective bytes (only image data - no headers)
                    image_compressed, image_bytes = jpeg_helpers.compress_batch(
                        image, q, effective=effective_bytes)

                    if write_files:
                        image_dir = os.path.join(directory, os.path.splitext(filename)[0])
                        if not os.path.isdir(image_dir):
                            os.makedirs(image_dir)
                        image_path = os.path.join(image_dir, 'jpeg_q{:03d}.png'.format(q))
                        imageio.imwrite(image_path, (255 * image_compressed).astype(np.uint8))

                    msssim_value = msssim(image, image_compressed, MAX=1).real

                    rows.append({
                        'image_id': image_id,
                        'filename': filename,
                        'codec': 'jpeg',
                        'quality': q,
                        'ssim': compare_ssim(image, image_compressed,
                                             multichannel=True, data_range=1),
                        'psnr': compare_psnr(image, image_compressed, data_range=1),
                        'msssim': msssim_value,
                        # Map MS-SSIM onto a dB-like scale
                        'msssim_db': -10 * np.log10(1 - msssim_value),
                        'bytes': image_bytes,
                        'bpp': 8 * image_bytes / image.shape[0] / image.shape[1]
                    })

                    pbar.set_postfix(image_id=image_id, quality=q)
                    pbar.update(1)

        df = pd.DataFrame(rows, columns=[
            'image_id', 'filename', 'codec', 'quality', 'ssim', 'psnr',
            'msssim', 'msssim_db', 'bytes', 'bpp'
        ])
        # Reuse the path computed above rather than rebuilding it
        df.to_csv(df_jpeg_path, index=False)

    return df
def match_jpeg(model, batch_x, axes=None, match='ssim'):
    """
    Compress a batch with the DCN model and with JPEG at a matched quality,
    print a comparison report, and plot full-frame + cropped previews.

    :param model: DCN model (provides compress(), get_codebook(), model_code)
    :param batch_x: batch of images (NHW3); only the first image is JPEG-compressed
    :param axes: optional sequence of 6 matplotlib axes to draw into; a new
        figure is created when None
    :param match: 'ssim' to match JPEG quality by SSIM, otherwise by bpp
    :return: the matplotlib figure containing the comparison
    """
    # Compress using DCN and get number of bytes
    batch_y, bytes_dcn = codec.simulate_compression(batch_x, model)
    ssim_dcn = compare_ssim(batch_x.squeeze(), batch_y.squeeze(),
                            multichannel=True, data_range=1)
    bpp_dcn = 8 * bytes_dcn / np.prod(batch_x.shape[1:-1])

    # Find the JPEG quality factor matching the DCN result (by SSIM or bpp)
    target = ssim_dcn if match == 'ssim' else bpp_dcn
    try:
        jpeg_quality = jpeg_helpers.match_quality(batch_x.squeeze(), target, match=match)
    except Exception:
        # Search failed - fall back to a coarse guess instead of aborting.
        # (Narrowed from a bare except, which also swallowed KeyboardInterrupt.)
        if match == 'ssim':
            jpeg_quality = 95 if ssim_dcn > 0.8 else 10
        else:
            jpeg_quality = 95 if bpp_dcn > 3 else 10
        print('WARNING Could not find a matching JPEG quality factor - guessing {}'
              .format(jpeg_quality))

    # Compress using JPEG
    batch_j, bytes_jpeg = jpeg_helpers.compress_batch(batch_x[0], jpeg_quality, effective=True)
    ssim_jpeg = compare_ssim(batch_x.squeeze(), batch_j.squeeze(),
                             multichannel=True, data_range=1)
    bpp_jpeg = 8 * bytes_jpeg / np.prod(batch_x.shape[1:-1])

    # Get stats of the latent representation
    code_book = model.get_codebook()
    batch_z = model.compress(batch_x)
    counts = utils.qhist(batch_z, code_book)
    counts = counts.clip(min=1)  # avoid log2(0) on empty histogram bins
    probs = counts / counts.sum()
    entropy = -np.sum(probs * np.log2(probs))

    # Print report
    print('DCN : {}'.format(model.model_code))
    print('Pixels : {}x{} = {:,} px'.format(
        batch_x.shape[1], batch_x.shape[2], np.prod(batch_x.shape[1:-1])))
    print('Bitmap : {:,} bytes'.format(np.prod(batch_x.shape)))
    print('Code-book size : {} elements from {} to {}'.format(
        len(code_book), min(code_book), max(code_book)))
    print('Entropy : {:.2f} bits per symbol'.format(entropy))
    print('Latent size : {:,}'.format(np.prod(batch_z.shape)))
    print('PPF Naive : {:,.0f} --> {:,.0f} bytes [{} bits per element]'.format(
        np.prod(batch_z.shape) * np.log2(len(code_book)) / 8,
        np.prod(batch_z.shape) * np.ceil(np.log2(len(code_book))) / 8,
        np.ceil(np.log2(len(code_book)))))
    print('PPF Theoretical : {:,.0f} bytes ({:.2f} bpp)'.format(
        np.prod(batch_z.shape) * entropy / 8,
        np.prod(batch_z.shape) * entropy / np.prod(batch_x.shape[1:-1])))
    print('FSE Coded : {:,} bytes ({:.2f} bpp) --> ssim: {:.3f}'.format(
        bytes_dcn, bpp_dcn, ssim_dcn))
    print('JPEG (Q={:2d}) : {:,} bytes ({:0.2f} bpp) --> ssim: {:.3f} // effective size disregarding JPEG headers'
          .format(jpeg_quality, bytes_jpeg, bpp_jpeg, ssim_jpeg))

    # Plot results
    if axes is None:
        fig, axes = plotting.sub(6, ncols=3)
        fig.set_size_inches(12, 10)
        fig.tight_layout()
    else:
        fig = axes[0].figure

    # Plot full-resolution
    plotting.quickshow(batch_x, 'Original ({0}x{0})'.format(batch_x.shape[1]),
                       axes=axes[0])
    plotting.quickshow(batch_y, 'DCN ssim:{:.2f} bpp:{:.2f}'.format(ssim_dcn, bpp_dcn),
                       axes=axes[1])
    plotting.quickshow(batch_j, 'JPEG {} ssim:{:.2f} bpp:{:.2f}'.format(
        jpeg_quality, ssim_jpeg, bpp_jpeg), axes=axes[2])

    # Plot zoom
    crop_size = max([64, batch_x.shape[1] // 4])
    plotting.quickshow(utils.crop_middle(batch_x, crop_size),
                       'Original crop ({0}x{0})'.format(crop_size), axes=axes[3])
    plotting.quickshow(utils.crop_middle(batch_y, crop_size),
                       'DCN crop ({0}x{0})'.format(crop_size), axes=axes[4])
    plotting.quickshow(utils.crop_middle(batch_j, crop_size),
                       'JPEG crop ({0}x{0})'.format(crop_size), axes=axes[5])

    return fig