def compress_n_stats(batch_x, dcn):
    """
    Simulate compression of a batch of images with a DCN model and compute per-image
    quality (SSIM, PSNR) and rate (entropy, bytes, bpp) statistics.
    """

    batch_y = np.zeros_like(batch_x)
    stats = {
        'ssim': np.zeros(batch_x.shape[0]),
        'psnr': np.zeros(batch_x.shape[0]),
        'entropy': np.zeros(batch_x.shape[0]),
        'bytes': np.zeros(batch_x.shape[0]),
        'bpp': np.zeros(batch_x.shape[0])
    }

    for image_id in range(batch_x.shape[0]):
        batch_y[image_id], image_bytes = simulate_compression(batch_x[image_id:image_id + 1], dcn)
        batch_z = dcn.compress(batch_x[image_id:image_id + 1])

        stats['bytes'][image_id] = image_bytes
        stats['entropy'][image_id] = utils.entropy(batch_z, dcn.get_codebook())
        stats['ssim'][image_id] = compare_ssim(batch_x[image_id], batch_y[image_id], multichannel=True, data_range=1)
        stats['psnr'][image_id] = compare_psnr(batch_x[image_id], batch_y[image_id], data_range=1)
        stats['bpp'][image_id] = 8 * image_bytes / batch_x[image_id].shape[0] / batch_x[image_id].shape[1]

    # For a single image, collapse the stats to scalars
    if batch_x.shape[0] == 1:
        for k in stats.keys():
            stats[k] = stats[k][0]

    return batch_y, stats
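
# Example usage (illustrative sketch only, not part of the original pipeline; assumes
# `data` is an IPDataset instance and `dcn` a trained / restored DCN model):
#
#   batch_x = data.next_validation_batch(0, 8)
#   batch_y, stats = compress_n_stats(batch_x, dcn)
#   print('SSIM {:.3f} / PSNR {:.2f} dB / {:.3f} bpp'.format(
#       np.mean(stats['ssim']), np.mean(stats['psnr']), np.mean(stats['bpp'])))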


def visualize_distribution(dcn, data, ax=None, title=None):
    """
    Plot the empirical histogram of the quantized latent representation against the soft
    (differentiable) estimate used for entropy regularization.
    """
    title = '' if title is None else title + ' '

    if type(data) is not np.ndarray:
        sample_batch_size = np.min((100, data.count_validation))
        batch_x = data.next_validation_batch(0, sample_batch_size)
    else:
        batch_x = data

    # Fetch latent distribution for the current batch
    batch_z = dcn.compress(batch_x)
    batch_z = batch_z.reshape((-1,)).T

    # Get current version of the quantization codebook
    codebook = dcn.get_codebook().tolist()

    # Find x limits for plotting
    if dcn._h.rounding == 'identity':
        qmax = np.ceil(np.max(np.abs(batch_z)))
        qmin = -qmax
    else:
        qmin = np.floor(codebook[0])
        qmax = np.ceil(codebook[-1])

    feed_dict = {dcn.x: batch_x}
    if hasattr(dcn, 'is_training'):
        feed_dict[dcn.is_training] = True

    # Get approximation of the soft quantization structures used for entropy estimation
    histogram = dcn.sess.run(dcn.histogram, feed_dict=feed_dict).reshape((-1,))
    histogram = histogram / histogram.max()
    histogram = histogram.reshape((-1)).tolist()

    # Create a dense version of the quantization bins
    bin_centers = np.arange(qmin - 1, qmax + 1, 0.1)
    bin_boundaries = np.convolve(bin_centers, [0.5, 0.5], mode='valid')
    bin_centers = bin_centers[1:-1]

    # Compute empirical histogram based on the latent representation
    hist = np.histogram(batch_z, bins=bin_boundaries, density=True)[0]
    hist = hist / hist.max()

    entropy = utils.entropy(batch_z, codebook)
    ticks = np.unique(np.round(np.percentile(batch_z, [1, 5, 25, 50, 75, 95, 99])))

    if ax is None:
        fig = plt.figure(figsize=(10, 2))
        ax = fig.gca()

    ax.set_xlim([qmin - 1, qmax + 1])
    ax.set_xticks(ticks)
    ax.stem(bin_centers, hist, linefmt='r:', markerfmt='r.')  # width=bin_centers[1] - bin_centers[0]
    ax.bar(codebook, histogram, width=(codebook[1] - codebook[0]) / 2, color='b', alpha=0.5)
    ax.set_title('{}QLR histogram (H={:.1f})'.format(title, entropy))
    ax.legend(['Quantized values', 'Soft estimate'], loc='upper right')

    # Return the figure so the caller can render or save it
    return ax.figure
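
# Example usage (illustrative sketch; assumes a live TF session inside `dcn` and a
# validation set available through `data`):
#
#   fig = visualize_distribution(dcn, data, title='epoch 10')
#   fig.savefig('latent_distribution.png', bbox_inches='tight')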


def validate_dcn(dcn, data, save_dir=None, epoch=0, show_ref=False):
    """
    Computes validation metrics for a compression model (DCN). If the model is not a DCN,
    the function returns immediately. If requested, plots compressed images to a JPEG file.

    :param dcn: the DCN model
    :param data: the dataset (instance of IPDataset)
    :param save_dir: path to the directory where figures should be generated (None to skip plotting)
    :param epoch: epoch counter to be appended to the output filename
    :param show_ref: whether to show only the compressed image or also the input image as reference
    :return: tuple of lists with per-image measurements of (ssims, psnrs, losses, entropies)
    """

    if not isinstance(dcn, DCN):
        return

    # Compute latent representations and compressed output
    batch_x = data.next_validation_batch(0, data.count_validation)
    if isinstance(batch_x, tuple):
        batch_x = batch_x[-1]

    batch_z = dcn.compress(batch_x, direct=True)
    batch_y = dcn.process(batch_x, direct=True)

    # Compute quality measures and entropy statistics
    codebook = dcn.get_codebook()

    ssims = [compare_ssim(batch_x[b], batch_y[b], multichannel=True, data_range=1) for b in range(data.count_validation)]
    psnrs = [compare_psnr(batch_x[b], batch_y[b], data_range=1) for b in range(data.count_validation)]
    losses = [dcn.sess.run(dcn.loss, feed_dict={dcn.x: batch_x[b:b + 1]}) for b in range(data.count_validation)]
    entropies = [utils.entropy(batch_z[b], codebook) for b in range(data.count_validation)]

    # If requested, plot a figure with input/output pairs
    if save_dir is not None:
        images_x = int(np.minimum(data.count_validation, 10 if not show_ref else 5))
        images_y = int(np.ceil(data.count_validation / images_x))
        fig = plt.figure(figsize=(20, 20 / images_x * images_y * (1 if not show_ref else 0.5)))

        for b in range(data.count_validation):
            ax = fig.add_subplot(images_y, images_x, b + 1)
            plotting.quickshow(
                np.concatenate((batch_x[b], batch_y[b]), axis=1) if show_ref else batch_y[b],
                '{:.1f} / {:.2f}'.format(psnrs[b], ssims[b]),
                axes=ax
            )

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        fig.savefig('{}/dcn_validation_{:05d}.jpg'.format(save_dir, epoch), bbox_inches='tight', dpi=100, quality=90)
        plt.close(fig)
        del fig

    return ssims, psnrs, losses, entropies
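
# Example usage (illustrative sketch; `dcn` and `data` as above, path is a placeholder):
#
#   ssims, psnrs, losses, entropies = validate_dcn(dcn, data, save_dir='./data/figures', epoch=5, show_ref=True)
#   print('Validation SSIM: {:.3f}'.format(np.mean(ssims)))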


def get_dcn_df(directory, model_directory, write_files=False, force_calc=False):
    """
    Compute and return (as a Pandas DataFrame) the rate-distortion curve for the learned DCN codec.

    The result is saved as a CSV file in the source directory. If the file already exists,
    the DataFrame is loaded from it and returned (unless force_calc is set).
    """

    # Discover test files
    files, _ = loading.discover_files(directory, n_images=-1, v_images=0)
    batch_x = loading.load_images(files, directory, load='y')
    batch_x = batch_x['y'].astype(np.float32) / (2**8 - 1)

    # Create a new table for the DCN
    df = pd.DataFrame(columns=[
        'image_id', 'filename', 'model_dir', 'codec', 'ssim', 'psnr', 'msssim', 'msssim_db',
        'entropy', 'bytes', 'bpp', 'layers', 'quantization', 'entropy_reg', 'codebook',
        'latent', 'latent_shape', 'n_features'
    ])

    # Discover available models
    model_dirs = list(Path(model_directory).glob('**/progress.json'))
    print('Found {} models'.format(len(model_dirs)))

    df_path = os.path.join(directory, 'dcn-{}.csv'.format(
        [x for x in coreutils.splitall(model_directory) if len(x) > 0][-1]))

    if os.path.isfile(df_path) and not force_calc:
        print('Restoring DCN stats from {}'.format(df_path))
        df = pd.read_csv(df_path, index_col=False)
    else:
        for model_dir in model_dirs:
            print('Processing: {}'.format(model_dir))
            dcn = codec.restore_model(os.path.split(str(model_dir))[0], batch_x.shape[1])

            # Dump compressed images
            for image_id, filename in enumerate(files):

                try:
                    batch_y, image_bytes = codec.simulate_compression(batch_x[image_id:image_id + 1], dcn)
                    batch_z = dcn.compress(batch_x[image_id:image_id + 1])
                    entropy = utils.entropy(batch_z, dcn.get_codebook())
                except Exception as e:
                    print('Error while processing {} with {} : {}'.format(filename, dcn.model_code, e))
                    raise e

                if write_files:
                    image_dir = os.path.join(directory, os.path.splitext(filename)[0])
                    if not os.path.isdir(image_dir):
                        os.makedirs(image_dir)

                    image_path = os.path.join(image_dir, dcn.model_code.replace('/', '-') + '.png')
                    imageio.imwrite(image_path, (255 * batch_y[0]).astype(np.uint8))

                msssim_value = msssim(batch_x[image_id], batch_y[0], MAX=1).real

                df = df.append({
                    'image_id': image_id,
                    'filename': filename,
                    'model_dir': os.path.relpath(os.path.split(str(model_dir))[0], model_directory).replace(dcn.scoped_name, ''),
                    'codec': dcn.model_code,
                    'ssim': compare_ssim(batch_x[image_id], batch_y[0], multichannel=True, data_range=1),
                    'psnr': compare_psnr(batch_x[image_id], batch_y[0], data_range=1),
                    'msssim': msssim_value,
                    'msssim_db': -10 * np.log10(1 - msssim_value),
                    'entropy': entropy,
                    'bytes': image_bytes,
                    'bpp': 8 * image_bytes / batch_x[image_id].shape[0] / batch_x[image_id].shape[1],
                    'layers': dcn.n_layers if 'n_layers' in dcn._h else None,
                    'quantization': '{}-{:.0f}bpf'.format(dcn._h.rounding, dcn.latent_bpf),
                    'entropy_reg': dcn.entropy_weight,
                    'codebook': dcn._h.rounding,
                    'latent': dcn.n_latent,
                    'latent_shape': '{}x{}x{}'.format(*dcn.latent_shape[1:]),
                    'n_features': dcn.latent_shape[-1]
                }, ignore_index=True)

        df.to_csv(df_path, index=False)

    return df
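
# Example usage (illustrative sketch; paths are placeholders): the per-image records can be
# aggregated into a rate-distortion summary per model, e.g.:
#
#   df = get_dcn_df('./data/test_images', './data/models/dcn', write_files=False)
#   rd_curve = df.groupby('codec')[['bpp', 'ssim', 'msssim_db']].mean()
#   print(rd_curve.sort_values('bpp'))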


def train_dcn(tf_ops, training, data, directory='./data/models/dcn/playground/', overwrite=False):
    """
    Train a DCN model.

    Expected structure of the input dictionaries (example):

        tf_ops = {
            'dcn': <DCN instance>
        }

        training = {
            'augmentation_probs': {'resize': 0.0, 'flip_h': 0.5, 'flip_v': 0.5, 'gamma': 0.5},
            # other keys accessed below: 'batch_size', 'patch_size', 'n_epochs', 'learning_rate',
            # 'learning_rate_reduction_schedule', 'learning_rate_reduction_factor',
            # 'sample_dropout', 'validation_schedule', 'validation_is_training',
            # 'convergence_threshold'
        }
    """

    dcn = tf_ops['dcn']
    dcn.init()

    # Compute the number of available batches
    n_batches = data['training']['y'].shape[0] // training['batch_size']
    v_batches = data['validation']['y'].shape[0] // training['batch_size']

    # Structures for storing performance stats
    perf = dcn.performance
    caches = {
        'loss': {'training': deque(maxlen=n_batches), 'validation': deque(maxlen=v_batches)},
        'entropy': {'training': deque(maxlen=n_batches), 'validation': deque(maxlen=v_batches)},
        'ssim': {'training': deque(maxlen=n_batches), 'validation': deque(maxlen=v_batches)}
    }

    n_tail = 5
    learning_rate = training['learning_rate']

    model_output_dirname = os.path.join(directory, dcn.model_code, dcn.scoped_name)

    if os.path.isdir(model_output_dirname) and not overwrite:
        return

    print('Output directory: {}'.format(model_output_dirname))

    # Create a summary writer and create the necessary directories
    sw = dcn.get_summary_writer(model_output_dirname)

    with tqdm.tqdm(total=training['n_epochs'], ncols=160, desc=dcn.model_code.split('/')[-1]) as pbar:

        for epoch in range(0, training['n_epochs']):

            training['current_epoch'] = epoch

            if epoch > 0 and epoch % training['learning_rate_reduction_schedule'] == 0:
                learning_rate *= training['learning_rate_reduction_factor']

            # Iterate through batches of the training data
            for batch_id in range(n_batches):

                # Pick a random patch size - the patch will be resized later for augmentation
                current_patch = np.random.choice(np.arange(training['patch_size'], 2 * training['patch_size']), 1) \
                    if np.random.uniform() < training['augmentation_probs']['resize'] else training['patch_size']

                # Sample the next batch
                batch_x = data.next_training_batch(batch_id, training['batch_size'], current_patch)

                # If rescaling is needed, apply it
                if training['patch_size'] != current_patch:
                    batch_t = np.zeros((batch_x.shape[0], training['patch_size'], training['patch_size'], 3), dtype=np.float32)
                    for i in range(len(batch_x)):
                        batch_t[i] = resize(batch_x[i], [training['patch_size'], training['patch_size']], anti_aliasing=True)
                    batch_x = batch_t

                # Data augmentation - random horizontal / vertical flips and gamma correction
                if np.random.uniform() < training['augmentation_probs']['flip_h']:
                    batch_x = batch_x[:, :, ::-1, :]

                if np.random.uniform() < training['augmentation_probs']['flip_v']:
                    batch_x = batch_x[:, ::-1, :, :]

                if np.random.uniform() < training['augmentation_probs']['gamma']:
                    batch_x = utils.batch_gamma(batch_x)

                # Sample dropout
                keep_prob = 1.0 if not training['sample_dropout'] else np.random.uniform(0.5, 1.0)

                # Make a training step
                values = dcn.training_step(batch_x, learning_rate, dropout_keep_prob=keep_prob)

                # TODO Temporary NaN hook
                if np.isnan(values['loss']):
                    print('NaN loss detected - dumping current variables')
                    codebook = dcn.get_codebook()

                    # Get some extra stats
                    if dcn.scale_latent:
                        scaling = dcn.sess.run(dcn.graph.get_tensor_by_name('{}/encoder/latent_scaling:0'.format(dcn.scoped_name)))
                    else:
                        scaling = np.nan

                    print('Scaling: {}'.format(scaling))
                    print('Codebook: {}'.format(codebook.tolist()))

                    # Dump all variables to check which of them contain NaNs
                    for var in dcn.parameters:
                        if np.any(np.isnan(dcn.sess.run(var))):
                            nan_perc = np.mean(np.isnan(dcn.sess.run(var)))
                            print('!! NaNs found in {} --> {}'.format(var.name, nan_perc))

                    return None

                for key, value in values.items():
                    caches[key]['training'].append(value)

            # Record average values for the whole epoch
            for key in ['loss', 'ssim', 'entropy']:
                perf[key]['training'].append(float(np.mean(caches[key]['training'])))

            # Get some extra stats
            if dcn.scale_latent:
                scaling = dcn.sess.run(dcn.graph.get_tensor_by_name('{}/encoder/latent_scaling:0'.format(dcn.scoped_name)))
            else:
                scaling = np.nan

            codebook = dcn.get_codebook()

            # Iterate through batches of the validation data
            if epoch % training['validation_schedule'] == 0:

                for batch_id in range(v_batches):
                    batch_x = data.next_validation_batch(batch_id, training['batch_size'])
                    batch_z = dcn.compress(batch_x, is_training=training['validation_is_training'])
                    batch_y = dcn.decompress(batch_z)

                    # Compute loss
                    loss_value = np.linalg.norm(batch_x - batch_y)
                    caches['loss']['validation'].append(loss_value)

                    # Compute SSIM
                    ssim_value = np.mean([ssim(batch_x[r], batch_y[r], multichannel=True, data_range=1.0) for r in range(len(batch_x))])
                    caches['ssim']['validation'].append(ssim_value)

                    # Compute entropy
                    entropy_value = utils.entropy(batch_z, codebook)
                    caches['entropy']['validation'].append(entropy_value)

                for key in ['loss', 'ssim', 'entropy']:
                    perf[key]['validation'].append(float(np.mean(caches[key]['validation'])))

                # Save a snapshot of the current reconstructions
                indices = np.argsort(np.var(batch_x, axis=(1, 2, 3)))[::-1]
                thumbs_pairs_all = np.concatenate((batch_x[indices[::2]], batch_y[indices[::2]]), axis=0)
                thumbs_pairs_few = np.concatenate((batch_x[indices[:5]], batch_y[indices[:5]]), axis=0)
                thumbs = (255 * plotting.thumbnails(thumbs_pairs_all, n_cols=training['batch_size'] // 2)).astype(np.uint8)
                thumbs_few = (255 * plotting.thumbnails(thumbs_pairs_few, n_cols=5)).astype(np.uint8)
                imageio.imsave(os.path.join(model_output_dirname, 'thumbnails-{:05d}.png'.format(epoch)), thumbs)

                # Sample the latent space
                batch_z = dcn.compress(batch_x)

                # Save summaries to TensorBoard
                summary = tf.Summary()
                summary.value.add(tag='loss/validation', simple_value=perf['loss']['validation'][-1])
                summary.value.add(tag='loss/training', simple_value=perf['loss']['training'][-1])
                summary.value.add(tag='ssim/validation', simple_value=perf['ssim']['validation'][-1])
                summary.value.add(tag='ssim/training', simple_value=perf['ssim']['training'][-1])
                summary.value.add(tag='entropy/training', simple_value=perf['entropy']['training'][-1])
                summary.value.add(tag='scaling', simple_value=scaling)
                summary.value.add(tag='images/reconstructed', image=summaries.log_image(rescale(thumbs_few, 1.0, anti_aliasing=True)))
                summary.value.add(tag='histograms/latent', histo=summaries.log_histogram(batch_z))
                summary.value.add(tag='histograms/latent_approx', image=summaries.log_plot(visualize_distribution(dcn, data)))

                if dcn.train_codebook:
                    summary.value.add(tag='codebook/min', simple_value=codebook.min())
                    summary.value.add(tag='codebook/max', simple_value=codebook.max())
                    summary.value.add(tag='codebook/mean', simple_value=codebook.mean())
                    summary.value.add(tag='codebook/diff_variance', simple_value=np.var(np.convolve(codebook, [-1, 1], mode='valid')))
                    summary.value.add(tag='codebook/centroids', image=summaries.log_plot(visualize_codebook(dcn)))

                sw.add_summary(summary, epoch)
                sw.flush()

                # Save stats to a JSON log
                save_progress(dcn, data, training, model_output_dirname)

                # Save the current checkpoint
                dcn.save_model(model_output_dirname, epoch)

                # Check for convergence or model deterioration
                if len(perf['ssim']['validation']) > 5:
                    current = np.mean(perf['ssim']['validation'][-n_tail:])
                    previous = np.mean(perf['ssim']['validation'][-(n_tail + 1):-1])
                    perf_change = abs((current - previous) / previous)

                    if perf_change < training['convergence_threshold']:
                        print('Early stopping - the model converged, validation SSIM change {:.4f}'.format(perf_change))
                        break

                    if current < 0.9 * previous:
                        print('Error - SSIM deterioration by more than 10% {:.4f} -> {:.4f}'.format(previous, current))
                        break

            progress_dict = {
                'L': np.mean(perf['loss']['training'][-3:]),
                'Lv': np.mean(perf['loss']['validation'][-1:]),
                'lr': '{:.1e}'.format(learning_rate),
                'ssim': '{:.2f}'.format(perf['ssim']['validation'][-1]),
                'H': '{:.1f}'.format(np.mean(perf['entropy']['training'][-1:])),
            }

            if dcn.scale_latent:
                progress_dict['S'] = '{:.1f}'.format(scaling)

            if dcn.use_batchnorm:
                # Get the current batch / population statistics
                prebn = dcn.sess.run(dcn.pre_bn, feed_dict={dcn.x: batch_x})
                bM = np.mean(prebn, axis=(0, 1, 2))
                bV = np.var(prebn, axis=(0, 1, 2))
                pM = dcn.sess.run(dcn.graph.get_tensor_by_name('{}/encoder/bn_0/moving_mean:0'.format(dcn.scoped_name)))
                pV = dcn.sess.run(dcn.graph.get_tensor_by_name('{}/encoder/bn_0/moving_variance:0'.format(dcn.scoped_name)))

                # Append the batch-norm summary
                progress_dict['MVp'] = '{:.2f}/{:.2f}'.format(np.mean(pM), np.mean(pV))
                progress_dict['MVb'] = '{:.2f}/{:.2f}'.format(np.mean(bM), np.mean(bV))

            # Update the progress bar
            pbar.set_postfix(progress_dict)
            pbar.update(1)
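
# Example usage (illustrative sketch; the hyper-parameter values below are placeholders,
# not recommended settings):
#
#   training_spec = {
#       'batch_size': 50, 'patch_size': 128, 'n_epochs': 1500, 'learning_rate': 1e-4,
#       'learning_rate_reduction_schedule': 1000, 'learning_rate_reduction_factor': 0.5,
#       'validation_schedule': 50, 'validation_is_training': False, 'sample_dropout': False,
#       'convergence_threshold': 1e-4,
#       'augmentation_probs': {'resize': 0.0, 'flip_h': 0.5, 'flip_v': 0.5, 'gamma': 0.5}
#   }
#   train_dcn({'dcn': dcn}, training_spec, data, directory='./data/models/dcn/playground/')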