def compare_batches(batch_a, batch_b, labels=None, fig=None, figwidth=4, nrows=3, transpose=False):
    """
    Plot two batches of images against each other, one column per (A, B) image pair.

    Row 0 shows images from batch A, row 1 the corresponding images from batch B annotated with
    PSNR / SSIM between A and B. If ``nrows > 2`` a third row shows the normalized signed
    difference; if ``nrows > 3`` a fourth row shows the normalized absolute difference.

    :param batch_a: first batch of images
    :param batch_b: second batch of images; only the first min(len(batch_a), len(batch_b)) pairs are shown
    :param labels: optional [label_a, label_b] appended to the (A) / (B) titles
    :param fig: optional existing figure, passed through to plots.sub
    :param figwidth: passed through to plots.sub
    :param nrows: number of rows to draw (rows 0 and 1 are always drawn, so nrows should be >= 2)
    :param transpose: passed through to plots.sub
    :return: the figure returned by plots.sub
    """
    n_images = min(len(batch_a), len(batch_b))
    labels = labels or ['', '']

    fig, axes = plots.sub(n_images * nrows, figwidth, ncols=n_images, fig=fig, transpose=transpose)

    for i, (img_a, img_b) in enumerate(zip(batch_a, batch_b)):
        label_a = f'(A) {labels[0]}'
        plots.image(img_a, label_a, axes=axes[i + n_images * 0])

        psnr = metrics.psnr(img_a, img_b)
        ssim = metrics.ssim(img_a, img_b)
        label_b = f'(B) {labels[1]}: {psnr:.1f} dB / {ssim:.3f}'
        plots.image(img_b, label_b, axes=axes[i + n_images * 1])

        # NOTE(review): the signed difference is computed as B - A although it is labelled 'A - B'.
        diff_ab = img_b - img_a
        diff_abs = np.abs(diff_ab)
        # The titles report the *mean absolute* difference; averaging the signed difference
        # would let positive and negative errors cancel out.
        diff_mean_abs = diff_abs.mean()

        if nrows > 2:
            plots.image(image.normalize(diff_ab), f'A - B: mean abs {diff_mean_abs:.3f}',
                        axes=axes[i + n_images * 2])

        if nrows > 3:
            plots.image(image.normalize(diff_abs, 0.1), f'|A - B|: mean abs {diff_mean_abs:.3f}',
                        axes=axes[i + n_images * 3])

    return fig
def validate_dcn(dcn, data, save_dir=False, epoch=0, show_ref=False):
    """
    Computes validation metrics for a compression model (DCN). (If not a DCN, the function returns
    None immediately). If requested, plot compressed images to a JPEG image.

    :param dcn: the DCN model
    :param data: the dataset (instance of Dataset)
    :param save_dir: path to the directory where figures should be generated; falsy values
                     (None, False, '') disable plotting
    :param epoch: epoch counter to be appended to the output filename
    :param show_ref: whether to show only the compressed image or also the input image as reference
    :return: dict with keys 'ssim' and 'psnr' (means over the validation batch), 'loss' and
             'entropy' (as returned by the model); None if ``dcn`` is not a DCN
    """
    if not isinstance(dcn, DCN):
        return None

    # Compute latent representations and compressed output
    batch_x = data.next_validation_batch(0, data.count_validation)
    if isinstance(batch_x, tuple):
        # Some datasets return (inputs, targets); the DCN operates on the last element
        batch_x = batch_x[-1]

    batch_y, entropy = dcn.process(batch_x, return_entropy=True)
    entropy = float(entropy.numpy())
    batch_y_np = batch_y.numpy()

    ssim = metrics.ssim(batch_x, batch_y_np).tolist()
    psnr = metrics.psnr(batch_x, batch_y_np).tolist()
    loss = float(dcn.loss(batch_x, batch_y, entropy).numpy())

    # If requested, plot a figure with input/output pairs. A truthiness test is used so that the
    # default (False) and None both disable plotting.
    if save_dir:
        # Matplotlib requires integer subplot grid dimensions (np.ceil returns a float)
        images_x = int(np.minimum(data.count_validation, 10 if not show_ref else 5))
        images_y = int(np.ceil(data.count_validation / images_x))
        fig = Figure(figsize=(20, 20 / images_x * images_y * (1 if not show_ref else 0.5)))

        for b in range(data.count_validation):
            ax = fig.add_subplot(images_y, images_x, b + 1)
            if show_ref:
                thumb = np.concatenate((batch_x[b], batch_y_np[b]), axis=1)
            else:
                thumb = batch_y_np[b]
            plots.image(thumb, '{:.1f} / {:.2f}'.format(psnr[b], ssim[b]), axes=ax)

        os.makedirs(save_dir, exist_ok=True)

        # JPEG quality must be passed via pil_kwargs (the `quality` savefig kwarg was removed)
        fig.savefig('{}/dcn_validation_{:05d}.jpg'.format(save_dir, epoch), bbox_inches='tight',
                    dpi=100, pil_kwargs={'quality': 90})
        del fig

    return {
        'ssim': float(np.mean(ssim)),
        'psnr': float(np.mean(psnr)),
        'loss': loss,
        'entropy': entropy
    }
def validate(model, data, out_directory, savefig=False, epoch=0, show_ref=False, loss_metric='L2'):
    """
    Develops every validation example with the given model and computes per-image quality measures.
    Optionally plots the developed images (with PSNR / SSIM titles) into a JPEG figure.

    :param model: the model; its ``process`` output is clipped to [0, 1]
    :param data: the dataset (instance of Dataset)
    :param out_directory: directory where the figure is written (created if missing)
    :param savefig: whether to plot and save the figure
    :param epoch: epoch counter appended to the output filename
    :param show_ref: whether to show the target next to each developed image
    :param loss_metric: 'L2' (MSE on a [0, 255] scale), 'L1' (MAE on a [0, 255] scale) or
                        'SSIM' (255 * (1 - SSIM))
    :return: tuple (ssims, psnrs, losss, developed_out) - three per-image lists and a float32 array
             with all developed images
    :raises ValueError: if loss_metric is not supported
    """
    # 'MS-SSIM' used to pass this check but was then rejected inside the loop (after processing
    # the first image); only accept losses that are actually implemented below.
    if loss_metric not in ('L2', 'L1', 'SSIM'):
        raise ValueError('Unsupported loss ({})!'.format(loss_metric))

    ssims, psnrs, losss = [], [], []

    if savefig:
        # Matplotlib requires integer subplot grid dimensions (np.ceil returns a float)
        images_x = int(np.minimum(data.count_validation, 20 if not show_ref else 10))
        images_y = int(np.ceil(data.count_validation / images_x))
        fig = Figure(figsize=(40, 1.1 * 40 / images_x * images_y * (1 if not show_ref else 0.5)))

    developed_out = np.zeros_like(data['validation']['y'], dtype=np.float32)

    for b in range(data.count_validation):

        # Fetch the next example and develop the RGB image
        example_x, example_y = data.next_validation_batch(b, 1)
        developed = model.process(example_x).numpy().clip(0, 1)
        developed_out[b, :, :, :] = developed
        developed = developed.squeeze()
        reference = example_y.squeeze()

        # Compute loss & quality metrics
        ssim = float(metrics.ssim(reference, developed))
        psnr = float(metrics.psnr(reference, developed))

        if loss_metric == 'L2':
            loss = metrics.mse(255 * reference, 255 * developed)
        elif loss_metric == 'L1':
            loss = metrics.mae(255 * reference, 255 * developed)
        else:  # 'SSIM' - the only remaining option after the validation above
            loss = 255 * (1 - metrics.ssim(reference, developed))

        ssims.append(ssim)
        psnrs.append(psnr)
        losss.append(loss)

        if savefig:
            ax = fig.add_subplot(images_y, images_x, b + 1)
            if show_ref:
                ax.imshow(np.concatenate((reference, developed), axis=1))
            else:
                ax.imshow(developed)
            ax.set_xticks([])
            ax.set_yticks([])
            # Map the example index back to its source file (several patches per file)
            label_index = int(b // (data.count_validation / len(data.files['validation'])))
            ax.set_title('{} : {:.1f} dB / {:.2f}'.format(
                data.files['validation'][label_index], psnr, ssim), fontsize=6)

    if savefig:
        os.makedirs(out_directory, exist_ok=True)
        fig.savefig(os.path.join(out_directory, 'validation_{:05d}.jpg'.format(epoch)),
                    bbox_inches='tight', dpi=150)
        del fig

    return ssims, psnrs, losss, developed_out
def match_jpeg(model, batch_x, axes=None, match='ssim'):
    """
    Compress an image with a DCN model and with JPEG at a matching quality factor, print a
    size / quality report and plot both results (full resolution and a central crop).

    :param model: the DCN model
    :param batch_x: input batch; the JPEG branch compresses only ``batch_x[0]``, so a single-image
                    batch is presumably expected
    :param axes: optional sequence of at least 6 axes to draw into; a new 2x3 figure is created if None
    :param match: 'ssim' to match JPEG to the DCN's SSIM, anything else to match its bit-rate (bpp)
    :return: the matplotlib figure
    """
    n_pixels = np.prod(batch_x.shape[1:-1])

    # Compress using DCN and get number of bytes
    batch_y, bytes_dcn = codec.simulate_compression(batch_x, model)
    ssim_dcn = metrics.ssim(batch_x.squeeze(), batch_y.squeeze()).mean()
    bpp_dcn = 8 * bytes_dcn / n_pixels
    target = ssim_dcn if match == 'ssim' else bpp_dcn

    try:
        jpeg_quality = jpeg_helpers.match_quality(batch_x.squeeze(), target, match=match)
    except Exception:
        # Best-effort fallback (deliberately not a bare except, so Ctrl+C still works):
        # guess a high or low quality factor depending on the DCN result.
        if match == 'ssim':
            jpeg_quality = 95 if ssim_dcn > 0.8 else 10
        else:
            jpeg_quality = 95 if bpp_dcn > 3 else 10
        print('WARNING Could not find a matching JPEG quality factor - guessing {}'.format(jpeg_quality))

    # Compress using JPEG
    batch_j, bytes_jpeg = jpeg_helpers.compress_batch(batch_x[0], jpeg_quality, effective=True)
    ssim_jpeg = metrics.ssim(batch_x.squeeze(), batch_j.squeeze()).mean()
    bpp_jpg = 8 * bytes_jpeg / n_pixels

    # Get stats
    code_book = model.get_codebook()
    batch_z = model.compress(batch_x).numpy()
    counts = helpers.stats.hist(batch_z, code_book)
    # Clip to 1 so that empty bins do not produce log2(0) in the entropy below
    counts = counts.clip(min=1)
    probs = counts / counts.sum()
    entropy = -np.sum(probs * np.log2(probs))

    n_latent = np.prod(batch_z.shape)
    bits_per_element = np.log2(len(code_book))

    # Print report
    print('DCN : {}'.format(model.model_code))
    print('Pixels : {}x{} = {:,} px'.format(batch_x.shape[1], batch_x.shape[2], n_pixels))
    print('Bitmap : {:,} bytes'.format(np.prod(batch_x.shape)))
    print('Code-book size : {} elements from {} to {}'.format(len(code_book), min(code_book), max(code_book)))
    print('Entropy : {:.2f} bits per symbol'.format(entropy))
    print('Latent size : {:,}'.format(n_latent))
    print('PPF Naive : {:,.0f} --> {:,.0f} bytes [{} bits per element]'.format(
        n_latent * bits_per_element / 8,
        n_latent * np.ceil(bits_per_element) / 8,
        np.ceil(bits_per_element)))
    print('PPF Theoretical : {:,.0f} bytes ({:.2f} bpp)'.format(
        n_latent * entropy / 8, n_latent * entropy / n_pixels))
    print('FSE Coded : {:,} bytes ({:.2f} bpp) --> ssim: {:.3f}'.format(bytes_dcn, bpp_dcn, ssim_dcn))
    print('JPEG (Q={:2d}) : {:,} bytes ({:0.2f} bpp) --> ssim: {:.3f} // effective size disregarding JPEG headers'.format(
        jpeg_quality, bytes_jpeg, bpp_jpg, ssim_jpeg))

    # Plot results
    if axes is None:
        fig, axes = plots.sub(6, ncols=3)
        fig.set_size_inches(12, 10)
        fig.tight_layout()
    else:
        fig = axes[0].figure

    # Plot full-resolution
    plots.image(batch_x, 'Original ({0}x{0})'.format(batch_x.shape[1]), axes=axes[0])
    plots.image(batch_y, 'DCN ssim:{:.2f} bpp:{:.2f}'.format(ssim_dcn, bpp_dcn), axes=axes[1])
    plots.image(batch_j, 'JPEG {} ssim:{:.2f} bpp:{:.2f}'.format(jpeg_quality, ssim_jpeg, bpp_jpg), axes=axes[2])

    # Plot zoom
    crop_size = max([64, batch_x.shape[1] // 4])
    plots.image(helpers.image.crop_middle(batch_x, crop_size),
                'Original crop ({0}x{0})'.format(crop_size), axes=axes[3])
    plots.image(helpers.image.crop_middle(batch_y, crop_size),
                'DCN crop ({0}x{0})'.format(crop_size), axes=axes[4])
    plots.image(helpers.image.crop_middle(batch_j, crop_size),
                'JPEG crop ({0}x{0})'.format(crop_size), axes=axes[5])

    return fig
def show_example(model, batch_x):
    """
    Visualize a compression model on a batch: compare the empirical latent histogram with the
    codebook-based estimate (and their entropies), and show input / reconstruction thumbnails.

    :param model: the compression model (must provide compress, decompress and get_codebook)
    :param batch_x: input batch; variance is taken over axes (1, 2, 3), so a 4-D batch is expected
    :return: the matplotlib figure
    """
    # Compress and decompress model
    batch_z = model.compress(batch_x).numpy()
    batch_y = model.decompress(batch_z).numpy()

    # Get empirical histogram of the latent representation
    codebook = model.get_codebook()

    qmin = np.floor(codebook[0])
    qmax = np.ceil(codebook[-1])

    # Fine 0.1-wide bins; boundaries are midpoints between consecutive centers, so the outermost
    # centers are dropped to keep one center per bin
    bin_centers = np.arange(qmin - 1, qmax + 1, 0.1)
    bin_boundaries = np.convolve(bin_centers, [0.5, 0.5], mode='valid')
    bin_centers = bin_centers[1:-1]

    hist_emp = np.histogram(batch_z.reshape((-1, )), bins=bin_boundaries, density=True)[0]
    hist_emp = np.maximum(hist_emp, 1e-9)
    hist_emp = hist_emp / hist_emp.sum()

    # Get TF histogram estimate based on soft quantization
    hist = helpers.stats.hist(batch_z, codebook)
    hist = hist / hist.sum()

    # Entropy - empty bins contribute 0 (lim p->0 of p*log p); evaluating 0 * log2(0) directly
    # would turn the whole sum into NaN
    hist_nonzero = hist[hist > 0]
    entropy = -np.sum(hist_nonzero * np.log2(hist_nonzero))
    entropy_emp = -np.sum(hist_emp * np.log2(hist_emp))

    fig, axes = plots.sub(2, ncols=1)
    fig.set_size_inches(12, 10)

    axes[0].plot(bin_centers, hist_emp / hist_emp.max(), 'r-')
    axes[0].plot(codebook, hist / hist.max(), '-bo')

    axes[0].legend([
        'Empirical H={:.2f}'.format(entropy_emp),
        'TF estimate (soft) H={:.2f}'.format(entropy)
    ])
    axes[0].set_ylabel('normalized frequency')
    axes[0].set_xlabel('latent values')

    # Thumbnails - order pairs by decreasing input variance
    indices = np.argsort(np.var(batch_x, axis=(1, 2, 3)))[::-1]
    thumbs_pairs_few = np.concatenate((batch_x[indices], batch_y[indices]), axis=0)
    thumbs_few = (255 * plots.thumbnails(thumbs_pairs_few, ncols=len(batch_x))).astype(np.uint8)

    ssim_values = [metrics.ssim(batch_x[i], batch_y[i]).mean() for i in range(len(batch_x))]

    plots.image(thumbs_few, 'Sample reconstructions, ssim={:.3f}'.format(np.mean(ssim_values)), axes=axes[1])

    fig.tight_layout()
    return fig
def _link_image_axes(ax, ref):
    """Make ``ax`` share its x and y view limits with ``ref`` so that zooming / panning is linked."""
    if hasattr(ax, 'sharex'):
        # Matplotlib >= 3.3 public API (Grouper.join on get_shared_*_axes() was removed in 3.8)
        ax.sharex(ref)
        ax.sharey(ref)
    else:
        # Older Matplotlib: join the shared-axes groupers directly
        ax.get_shared_x_axes().join(ref, ax)
        ax.get_shared_y_axes().join(ref, ax)


def compare_images_ab_ref(img_ref, img_a, img_b, labels=None, extras=False, fig=None):
    """
    Compare two images (A, B) against a target (T), annotating A and B with PSNR / SSIM w.r.t. T.

    With ``extras`` a 3x3 grid additionally shows normalized absolute differences (T-A, T-B, A-B)
    and log-normalized spectra of the differences; otherwise a single row of 3 images is drawn.

    :param img_ref: target image
    :param img_a: first image to compare
    :param img_b: second image to compare
    :param labels: optional [target_label, label_a, label_b]
    :param extras: whether to also plot difference images and spectra
    :param fig: optional existing figure, passed through to plots.sub
    :return: the figure returned by plots.sub
    """
    labels = labels or ['target', '', '']
    img_a = img_a.squeeze()
    img_b = img_b.squeeze()
    img_ref = img_ref.squeeze()

    fig, axes = plots.sub(9 if extras else 3, ncols=3, fig=fig)

    # Index of the axes for image B (first cell of the second row in the 3x3 layout)
    j = 3 if extras else 2

    plots.image(img_ref, '(T) {}'.format(labels[0]), axes=axes[0])
    label_a = '(A) {}: {:.1f} dB / {:.3f}'.format(labels[1], metrics.psnr(img_ref, img_a), metrics.ssim(img_ref, img_a))
    plots.image(img_a, label_a, axes=axes[1])
    label_b = '(B) {}: {:.1f} dB / {:.3f}'.format(labels[2], metrics.psnr(img_ref, img_b), metrics.ssim(img_ref, img_b))
    plots.image(img_b, label_b, axes=axes[j])

    # Allow the image axes to zoom together
    _link_image_axes(axes[1], axes[0])
    _link_image_axes(axes[j], axes[0])

    if not extras:
        return fig

    # Compute and plot difference images
    diff_a = np.abs(img_a - img_ref)
    diff_a_mean = diff_a.mean()
    diff_a = image.normalize(diff_a, 0.1)

    diff_b = np.abs(img_b - img_ref)
    diff_b_mean = diff_b.mean()
    diff_b = image.normalize(diff_b, 0.1)

    diff_ab = np.abs(img_b - img_a)
    diff_ab_mean = diff_ab.mean()
    diff_ab = image.normalize(diff_ab, 0.1)

    plots.image(diff_a, 'T - A: mean abs {:.3f}'.format(diff_a_mean), axes=axes[2])
    plots.image(diff_b, 'T - B: mean abs {:.3f}'.format(diff_b_mean), axes=axes[6])
    plots.image(diff_ab, 'A - B: mean abs {:.3f}'.format(diff_ab_mean), axes=axes[4])

    # Compute and plot spectra (of the normalized difference images computed above)
    fft_a = fft_log_norm(diff_a)
    fft_b = fft_log_norm(diff_b)
    fft_ab = image.normalize(np.abs(fft_log_norm(img_b) - fft_log_norm(img_a)), 0.01)

    plots.image(fft_a, 'FFT(T - A)', axes=axes[5])
    plots.image(fft_b, 'FFT(T - B)', axes=axes[7])
    plots.image(fft_ab, 'FFT(A) - FFT(B)', axes=axes[8])

    return fig
def develop_image(pipeline, camera=None, batch=None, image=None, patch_size=0, patches=2, root_dir='./data', pipeline_args=None):
    """
    Display a patch developed by a neural imaging pipeline.

    :param pipeline: either a directory with a NIP training log (the model is restored from it) or
                     the class name of a pipeline in ``models.pipelines``
    :param camera: camera name; required to locate pre-trained weights (when ``pipeline`` is a class
                   name) and to load validation data / camera config (when ``batch`` is used)
    :param batch: number of validation images to load from the camera's training data
    :param image: path to a RAW image to develop (takes precedence over ``batch``)
    :param patch_size: size of the central crop to show (0 = whole image)
    :param patches: number of validation patches (used together with ``batch``)
    :param root_dir: root data directory
    :param pipeline_args: optional keyword arguments for the pipeline constructor
    :raises ValueError: for an unknown / missing camera, an invalid patch size, or missing inputs
    :raises FileNotFoundError: if no pre-trained weights can be found for the pipeline
    """
    # Validate inputs before the (slow) lazy imports below. Without these checks a missing input
    # surfaced later as an unbound local variable (sample_x / root_dirname / data_dirname).
    if image is None and batch is None:
        raise ValueError('Specify either a RAW image or a batch size!')

    if camera is not None:
        supported_cameras = fsutil.listdir(os.path.join(root_dir, 'models', 'nip'), '.*')
        if camera not in supported_cameras:
            raise ValueError('Camera data not found ({})! Available cameras: {}'.format(camera, ', '.join(supported_cameras)))
        root_dirname = os.path.join(root_dir, 'models', 'nip', camera)
        data_dirname = os.path.join(root_dir, 'raw', 'training_data', camera)
    else:
        if image is None:
            raise ValueError('A camera is required to load a batch of validation images!')
        if not os.path.isdir(pipeline):
            raise ValueError('A camera is required to find pre-trained weights for pipeline {}!'.format(pipeline))

    if patch_size != 0 and (patch_size < 4 or patch_size > 2048):
        raise ValueError('Patch size seems to be invalid!')

    # Lazy imports to minimize delay for invalid command line parameters
    import numpy as np
    import imageio as io
    import matplotlib.pyplot as plt
    import tensorflow as tf
    from models import pipelines

    # Construct the NIP model ---------------------------------------------------------------------

    if os.path.isdir(pipeline):
        # Restore a NIP model from a training log
        model = tfmodel.restore(pipeline, pipelines)
    else:
        # Construct the NIP model from class name (and optional arguments)
        if pipeline_args is None:
            model = getattr(pipelines, pipeline)()
        else:
            model = getattr(pipelines, pipeline)(**pipeline_args)

        loaded_model = False
        candidate_dirs = [os.path.join(root_dirname, model.model_code), os.path.join(root_dirname)]
        for candidate in candidate_dirs:
            if os.path.isdir(candidate):
                model.load_model(candidate)
                loaded_model = True
                break

        if not loaded_model:
            raise FileNotFoundError(f'Could not find the corresponding model: {candidate_dirs}')

    # Load image(s) -------------------------------------------------------------------------------

    if image is None and batch is not None:
        print('Loading a batch of {} images'.format(batch))
        data = dataset.Dataset(data_dirname, n_images=0, v_images=batch, val_rgb_patch_size=patch_size or 256,
                               val_n_patches=patches)
        sample_x, sample_y = data.next_validation_batch(0, data.count_validation)

        # NOTE(review): path is relative to the working directory, not to root_dir
        with open('config/cameras.json') as f:
            cameras = json.load(f)
            cfa, srgb = cameras[camera]['cfa'], np.array(cameras[camera]['srgb'])

    elif image is not None:
        print('Loading a RAW image {}'.format(image))
        sample_x, cfa, srgb, _ = raw.unpack(image, expand=True)
        sample_y = raw.process(image, brightness=None, expand=True)

    if isinstance(model, pipelines.ClassicISP):
        print('Configuring ISP to CFA: {} & sRGB {}'.format(cfa, srgb.round(2).tolist()))
        model.set_cfa_pattern(cfa)
        model.set_srgb_conversion(srgb)

    sample_Y = model.process(sample_x).numpy()

    if patch_size > 0:
        # Central crop of both the target and the developed image
        xx = (sample_y.shape[2] - patch_size) // 2
        yy = (sample_y.shape[1] - patch_size) // 2
        sample_y = sample_y[:, yy:yy + patch_size, xx:xx + patch_size, :]
        sample_Y = sample_Y[:, yy:yy + patch_size, xx:xx + patch_size, :]

    psnrs = metrics.psnr(sample_y, sample_Y)
    ssims = metrics.ssim(sample_y, sample_Y)

    print('sample x: {}'.format(sample_x.shape))
    print('sample y: {}'.format(sample_y.shape))
    print('sample Y: {}'.format(sample_Y.shape))

    # Plot images ---------------------------------------------------------------------------------
    if len(sample_y) > 1:
        sample_y = plots.thumbnails(sample_y, batch, True)
        sample_Y = plots.thumbnails(sample_Y, batch, True)
    else:
        sample_y = sample_y.squeeze()
        sample_Y = sample_Y.squeeze()

    print('thumbnails: {}'.format(sample_y.shape))

    # Stack wide images vertically, tall ones side by side
    ncols = 1 if sample_y.shape[1] > sample_y.shape[0] else 2
    nrows = 2 if ncols == 1 else 1
    fig, axes = plt.subplots(nrows, ncols)

    plots.image(sample_Y, '{}, PSNR={:.1f} dB, SSIM={:.2f} : {{}}'.format(model.model_code, float(psnrs.mean()), float(ssims.mean())), axes=axes[0])
    plots.image(sample_y, 'Target RGB images () : {}', axes=axes[1])

    plt.show()
    plt.close()
def get_ssim(q):
    """
    Objective for matching a JPEG quality factor: SSIM of ``image`` after JPEG compression at
    quality ``q``, minus ``target`` (zero when the target SSIM is reached).

    NOTE(review): ``image``, ``subsampling``, ``target`` and ``compress_batch`` are free names -
    presumably bound by an enclosing function that is not visible here; confirm.
    """
    compressed = compress_batch(image, q, subsampling=subsampling)[0]
    return metrics.ssim(image, compressed.squeeze()) - target
def validate_nip(model, data, save_dir=False, epoch=0, show_ref=False, loss_type='L2'):
    """
    Develops image patches using the given NIP and returns standard image quality measures.
    If requested, resulting patches are visualized as thumbnails and saved to a directory.

    :param model: the NIP model; its ``process`` output is clipped to [0, 1]
    :param data: the dataset (instance of Dataset)
    :param save_dir: path to the directory where figures should be generated; falsy values
                     (None, False, '') disable plotting
    :param epoch: epoch counter to be appended to the output filename
    :param show_ref: whether to show only the developed image or also the GT target
    :param loss_type: L1 or L2
    :return: tuple of lists with per-image measurements of (ssims, psnrs, losss)
    :raises ValueError: if loss_type is neither 'L1' nor 'L2'
    """
    # Validate up-front instead of failing after the first image has already been processed
    if loss_type not in ('L1', 'L2'):
        raise ValueError('Invalid loss! Use either L1 or L2.')

    ssims = []
    psnrs = []
    losss = []

    # If requested, plot a figure with output/target pairs. A truthiness test is used so that the
    # default (False) and None both disable plotting.
    if save_dir:
        # Matplotlib requires integer subplot grid dimensions (np.ceil returns a float)
        images_x = int(np.minimum(data.count_validation, 10 if not show_ref else 5))
        images_y = int(np.ceil(data.count_validation / images_x))
        fig = Figure(figsize=(20, 20 / images_x * images_y * (1 if not show_ref else 0.5)))

    # NOTE(review): developed_out is filled but not returned by this function
    developed_out = np.zeros_like(data['validation']['y'], dtype=np.float32)

    for b in range(data.count_validation):
        example_x, example_y = data.next_validation_batch(b, 1)
        developed = model.process(example_x).numpy().clip(0, 1)
        developed_out[b, :, :, :] = developed
        developed = developed.squeeze()
        reference = example_y.squeeze()

        # Compute stats
        ssim = metrics.ssim(reference, developed).mean()
        psnr = metrics.psnr(reference, developed).mean()

        if loss_type == 'L2':
            loss = np.mean(np.power(reference - developed, 2.0))
        else:  # 'L1' - the only remaining option after the validation above
            loss = np.mean(np.abs(reference - developed))

        ssims.append(ssim)
        psnrs.append(psnr)
        losss.append(loss)

        # Add images to the plot
        if save_dir:
            ax = fig.add_subplot(images_y, images_x, b + 1)
            thumb = np.concatenate((reference, developed), axis=1) if show_ref else developed
            plots.image(thumb, '{:.1f} dB / {:.2f}'.format(psnr, ssim), axes=ax)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

        # JPEG quality must be passed via pil_kwargs (the `quality` savefig kwarg was removed)
        fig.savefig('{}/nip_validation_{:05d}.jpg'.format(save_dir, epoch), bbox_inches='tight',
                    dpi=100, pil_kwargs={'quality': 90})
        del fig

    return ssims, psnrs, losss