def get_dcn_df(directory, model_directory, write_files=False, force_calc=False):
    """
    Compute and return (as Pandas DF) the rate distortion curve for the learned DCN codec.

    The result is saved as a CSV file in the source directory. If the file exists
    and `force_calc` is not set, the DF is loaded from that file and returned
    without loading the images or restoring any model.

    :param directory: directory with the test images; the CSV file (and, when
        `write_files` is set, the decompressed PNG images) are written here
    :param model_directory: root directory searched recursively for
        `progress.json` files; a model is restored from the directory that
        contains each of them
    :param write_files: if True, save every decompressed image as a PNG in a
        sub-directory of `directory` named after the source image
    :param force_calc: if True, ignore an existing CSV file and recompute
    :return: Pandas DF with one row per (model, image) pair
    """
    columns = [
        'image_id', 'filename', 'model_dir', 'codec', 'ssim', 'psnr',
        'msssim', 'msssim_db', 'entropy', 'bytes', 'bpp', 'layers',
        'quantization', 'entropy_reg', 'codebook', 'latent', 'latent_shape',
        'n_features'
    ]

    # Discover available models
    model_dirs = list(Path(model_directory).glob('**/progress.json'))
    logger.info('Found {} models'.format(len(model_dirs)))

    # The CSV is named after the last non-empty component of the models path
    df_path = os.path.join(
        directory, 'dcn-{}.csv'.format(
            [x for x in fsutil.split(model_directory) if len(x) > 0][-1]))

    if os.path.isfile(df_path) and not force_calc:
        logger.info('Restoring DCN stats from {}'.format(df_path))
        return pd.read_csv(df_path, index_col=False)

    # Discover test files - only needed when the stats are actually computed,
    # so the images are not loaded when the cached CSV is used.
    files, _ = loading.discover_images(directory, n_images=-1, v_images=0)
    batch_x = loading.load_images(files, directory, load='y')
    batch_x = batch_x['y'].astype(np.float32) / (2**8 - 1)

    # Rows are collected in a list and converted once at the end:
    # DataFrame.append copies the whole frame on every call and is no longer
    # available in pandas >= 2.0.
    rows = []

    for model_dir in model_dirs:
        logger.info('Processing model dir: {}'.format(model_dir))
        model_path = os.path.split(str(model_dir))[0]
        dcn = codec.restore(model_path, batch_x.shape[1])

        # Dump compressed images
        for image_id, filename in enumerate(files):
            image = batch_x[image_id]

            try:
                batch_y, image_bytes = codec.simulate_compression(
                    batch_x[image_id:image_id + 1], dcn)
                batch_z = dcn.compress(batch_x[image_id:image_id + 1])
                entropy = helpers.stats.entropy(batch_z, dcn.get_codebook())
            except Exception as e:
                logger.error(
                    'Error while processing {} with {} : {}'.format(
                        filename, dcn.model_code, e))
                # Bare raise keeps the original traceback intact
                raise

            if write_files:
                image_dir = os.path.join(directory,
                                         os.path.splitext(filename)[0])
                os.makedirs(image_dir, exist_ok=True)
                image_path = os.path.join(
                    image_dir, dcn.model_code.replace('/', '-') + '.png')
                imageio.imwrite(image_path,
                                (255 * batch_y[0]).astype(np.uint8))

            msssim_value = msssim(image, batch_y[0], MAX=1).real

            rows.append({
                'image_id': image_id,
                'filename': filename,
                'model_dir': os.path.relpath(
                    model_path, model_directory).replace(dcn.scoped_name, ''),
                'codec': dcn.model_code,
                'ssim': compare_ssim(image, batch_y[0], multichannel=True,
                                     data_range=1),
                'psnr': compare_psnr(image, batch_y[0], data_range=1),
                'msssim': msssim_value,
                'msssim_db': -10 * np.log10(1 - msssim_value),
                'entropy': entropy,
                'bytes': image_bytes,
                'bpp': 8 * image_bytes / image.shape[0] / image.shape[1],
                'layers': dcn.n_layers if 'n_layers' in dcn._h else None,
                'quantization': '{}-{:.0f}bpf'.format(dcn._h.rounding,
                                                      dcn.latent_bpf),
                'entropy_reg': dcn.entropy_weight,
                'codebook': dcn._h.rounding,
                'latent': dcn.n_latent,
                'latent_shape': '{}x{}x{}'.format(*dcn.latent_shape[1:]),
                'n_features': dcn.latent_shape[-1]
            })

    # `columns` fixes the column order (and yields an empty but well-formed
    # table when no models / images were found).
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(df_path, index=False)

    return df
def match_jpeg(model, batch_x, axes=None, match='ssim'):
    """
    Compress an image with the DCN model and with JPEG at a matching operating
    point, print a size / quality report and plot both results side by side.

    :param model: DCN model; passed to `codec.simulate_compression` and queried
        for its code book, latent representation and `model_code`
    :param batch_x: input image batch; indexed as `batch_x[0]` and squeezed
        for the quality metrics (NOTE(review): looks like a single-image batch
        is expected - confirm against callers)
    :param axes: optional sequence of at least 6 axes to plot on; when None,
        a new figure with 6 sub-plots is created
    :param match: 'ssim' to pick the JPEG quality matching the DCN's SSIM; any
        other value matches the DCN's bits per pixel instead
    :return: the figure holding the plots
    """
    n_pixels = np.prod(batch_x.shape[1:-1])

    # Compress using DCN and get number of bytes
    batch_y, bytes_dcn = codec.simulate_compression(batch_x, model)
    ssim_dcn = compare_ssim(batch_x.squeeze(), batch_y.squeeze(),
                            multichannel=True, data_range=1)
    bpp_dcn = 8 * bytes_dcn / n_pixels

    target = ssim_dcn if match == 'ssim' else bpp_dcn

    try:
        jpeg_quality = jpeg_helpers.match_quality(batch_x.squeeze(), target,
                                                  match=match)
    except Exception:
        # Best-effort fallback when the search fails. `Exception` (instead of
        # a bare except) lets KeyboardInterrupt / SystemExit propagate.
        if match == 'ssim':
            jpeg_quality = 95 if ssim_dcn > 0.8 else 10
        else:
            jpeg_quality = 95 if bpp_dcn > 3 else 10
        print(
            'WARNING Could not find a matching JPEG quality factor - guessing {}'
            .format(jpeg_quality))

    # Compress using JPEG
    batch_j, bytes_jpeg = jpeg_helpers.compress_batch(batch_x[0], jpeg_quality,
                                                      effective=True)
    ssim_jpeg = compare_ssim(batch_x.squeeze(), batch_j.squeeze(),
                             multichannel=True, data_range=1)
    bpp_jpg = 8 * bytes_jpeg / n_pixels

    # Get stats
    code_book = model.get_codebook()
    batch_z = model.compress(batch_x)
    counts = utils.qhist(batch_z, code_book)
    # Zero counts are clipped to 1 so that log2 below stays finite
    counts = counts.clip(min=1)
    probs = counts / counts.sum()
    entropy = -np.sum(probs * np.log2(probs))

    n_latent = np.prod(batch_z.shape)
    bits_per_symbol = np.log2(len(code_book))
    bits_per_symbol_ceil = np.ceil(bits_per_symbol)

    # Print report
    print('DCN : {}'.format(model.model_code))
    print('Pixels : {}x{} = {:,} px'.format(
        batch_x.shape[1], batch_x.shape[2], n_pixels))
    print('Bitmap : {:,} bytes'.format(np.prod(batch_x.shape)))
    print('Code-book size : {} elements from {} to {}'.format(
        len(code_book), min(code_book), max(code_book)))
    print('Entropy : {:.2f} bits per symbol'.format(entropy))
    print('Latent size : {:,}'.format(n_latent))
    print('PPF Naive : {:,.0f} --> {:,.0f} bytes [{} bits per element]'.
          format(n_latent * bits_per_symbol / 8,
                 n_latent * bits_per_symbol_ceil / 8,
                 bits_per_symbol_ceil))
    print('PPF Theoretical : {:,.0f} bytes ({:.2f} bpp)'.format(
        n_latent * entropy / 8, n_latent * entropy / n_pixels))
    print('FSE Coded : {:,} bytes ({:.2f} bpp) --> ssim: {:.3f}'.format(
        bytes_dcn, bpp_dcn, ssim_dcn))
    print(
        'JPEG (Q={:2d}) : {:,} bytes ({:0.2f} bpp) --> ssim: {:.3f} // effective size disregarding JPEG headers'
        .format(jpeg_quality, bytes_jpeg, bpp_jpg, ssim_jpeg))

    # Plot results
    if axes is None:
        fig, axes = plotting.sub(6, ncols=3)
        fig.set_size_inches(12, 10)
        fig.tight_layout()
    else:
        fig = axes[0].figure

    # Plot full-resolution
    plotting.quickshow(batch_x,
                       'Original ({0}x{0})'.format(batch_x.shape[1]),
                       axes=axes[0])
    plotting.quickshow(batch_y,
                       'DCN ssim:{:.2f} bpp:{:.2f}'.format(ssim_dcn, bpp_dcn),
                       axes=axes[1])
    plotting.quickshow(batch_j,
                       'JPEG {} ssim:{:.2f} bpp:{:.2f}'.format(
                           jpeg_quality, ssim_jpeg, bpp_jpg),
                       axes=axes[2])

    # Plot zoom
    crop_size = max([64, batch_x.shape[1] // 4])
    plotting.quickshow(utils.crop_middle(batch_x, crop_size),
                       'Original crop ({0}x{0})'.format(crop_size),
                       axes=axes[3])
    plotting.quickshow(utils.crop_middle(batch_y, crop_size),
                       'DCN crop ({0}x{0})'.format(crop_size),
                       axes=axes[4])
    plotting.quickshow(utils.crop_middle(batch_j, crop_size),
                       'JPEG crop ({0}x{0})'.format(crop_size),
                       axes=axes[5])

    return fig