def load_model(checkpoint_dir):
    """Build the training-mode compression graph and restore its weights.

    The graph is: analysis transform -> hyper analysis on |y| -> entropy
    bottleneck for z -> hyper synthesis producing sigma -> Gaussian
    conditional bottleneck for y -> synthesis transform. Both bottlenecks are
    called with ``training=True``.

    Args:
        checkpoint_dir: Directory searched with ``tf.train.latest_checkpoint``.

    Returns:
        ``(sess, compression_model, model_input)``:
        - ``sess`` is a new session holding the restored variables (the caller
          owns it).
        - ``compression_model`` is a ``keras.models.Model`` mapping
          ``model_input`` to ``[y_likelihoods, z_likelihoods, x_tilde]``.
        - ``model_input`` is the Keras input tensor, shape (None, None, 3) per
          example.
    """
    filters = 192
    # Instantiation order is kept fixed so layer (and variable) names stay
    # the same as in the checkpoint.
    g_a = AnalysisTransform(filters)
    g_s = SynthesisTransform(filters)
    h_a = HyperAnalysisTransform(filters)
    h_s = HyperSynthesisTransform(filters)
    z_prior = tfc.EntropyBottleneck()

    # Construct the Keras graph.
    model_input = keras.layers.Input(shape=(None, None, 3))
    latents = g_a(model_input)
    hyper_latents = h_a(abs(latents))
    noisy_hyper, z_likelihoods = z_prior(hyper_latents, training=True)
    scales = h_s(noisy_hyper)
    log_scales = np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS)
    y_prior = tfc.GaussianConditional(scales, np.exp(log_scales))
    noisy_latents, y_likelihoods = y_prior(latents, training=True)
    reconstruction = g_s(noisy_latents)

    # Restore the newest checkpoint into a fresh session.
    sess = tf.Session()
    checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir)
    tf.train.Saver().restore(sess, save_path=checkpoint)

    compression_model = keras.models.Model(
        inputs=model_input,
        outputs=[y_likelihoods, z_likelihoods, reconstruction])
    return sess, compression_model, model_input
def load_model_with_input(checkpoint_dir, input_variable):
    """Build the compression graph on an existing tensor and restore only its variables.

    Any global variables already in the default graph before this call are
    excluded from the restore. Only the variables created here are loaded
    from the newest checkpoint in ``checkpoint_dir``.

    Args:
        checkpoint_dir: Directory searched with ``tf.train.latest_checkpoint``.
        input_variable: Image tensor fed to the analysis transform.

    Returns:
        ``(sess, input_variable, y_likelihoods, z_likelihoods, x_tilde)``.
        ``sess`` is a new session owned by the caller.
    """
    # Snapshot variable names so we can tell which ones this call creates.
    preexisting = {var.name for var in tf.global_variables()}

    filters = 192
    g_a = AnalysisTransform(filters)
    g_s = SynthesisTransform(filters)
    h_a = HyperAnalysisTransform(filters)
    h_s = HyperSynthesisTransform(filters)
    z_prior = tfc.EntropyBottleneck()

    # Construct the model on top of the supplied tensor.
    latents = g_a(input_variable)
    hyper_latents = h_a(abs(latents))
    noisy_hyper, z_likelihoods = z_prior(hyper_latents, training=True)
    scales = h_s(noisy_hyper)
    log_scales = np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS)
    y_prior = tfc.GaussianConditional(scales, np.exp(log_scales))
    noisy_latents, y_likelihoods = y_prior(latents, training=True)
    x_tilde = g_s(noisy_latents)

    sess = tf.Session()
    checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir)

    # Restore only what this function added to the graph.
    new_vars = [var for var in tf.global_variables() if var.name not in preexisting]
    tf.train.Saver(var_list=new_vars).restore(sess, save_path=checkpoint)
    return sess, input_variable, y_likelihoods, z_likelihoods, x_tilde
def decompress(packed_string, trained_model):
    """Decode a packed bitstream into an image plus the predicted mean/scale.

    Args:
        packed_string: Serialized data given to ``PackedTensors``. It is
            unpacked into (string, side_string, x_shape, y_shape, z_shape).
        trained_model: Checkpoint path passed directly to ``Saver.restore``.

    Returns:
        ``(x_hat, m, s)``:
        - ``x_hat`` is ``quantize_image`` applied to the reconstruction, cropped
          to the stored ``x_shape`` (batch dimension kept).
        - ``m`` and ``s`` are the evaluated mean and sigma, cropped to
          ``y_shape``.
    """
    # Placeholders for the packed fields, in packing order.
    string = tf.placeholder(tf.string, [1])
    side_string = tf.placeholder(tf.string, [1])
    x_shape = tf.placeholder(tf.int32, [2])
    y_shape = tf.placeholder(tf.int32, [2])
    z_shape = tf.placeholder(tf.int32, [2])
    tensors = [string, side_string, x_shape, y_shape, z_shape]
    arrays = PackedTensors(packed_string).unpack(tensors)

    synthesis_transform = SynthesisTransform(192)
    hyper_synthesis_transform = HyperSynthesisTransform(192)
    entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32)

    # Decode the side information (z) and predict mean/scale for y.
    z_full_shape = tf.concat([z_shape, [192]], axis=0)
    z_hat = entropy_bottleneck.decompress(side_string, z_full_shape, channels=192)
    mean, sigma = hyper_synthesis_transform(z_hat)
    mean = mean[:, :y_shape[0], :y_shape[1], :]
    sigma = sigma[:, :y_shape[0], :y_shape[1], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(
        sigma, scale_table, mean=mean, dtype=tf.float32)
    y_hat = conditional_bottleneck.decompress(string)
    x_hat = synthesis_transform(y_hat)
    # The decoded image can extend past the original size (bottom/right
    # padding); crop it to the stored x_shape. The batch dimension is kept
    # because callers receive a 4-D result.
    x_hat = x_hat[:, :x_shape[0], :x_shape[1], :]

    # Close the session once the results have been fetched; previously it
    # was never released.
    with tf.Session() as sess:
        tf.train.Saver().restore(sess, save_path=trained_model)
        x_hat, m, s = sess.run(
            [quantize_image(x_hat), mean, sigma],
            feed_dict=dict(zip(tensors, arrays)))
    return x_hat, m, s
def build_model(x, lmbda, mode='training', layers=None, msssim_loss=False):
    """Builds the compression model.

    Args:
        x: Input image batch. All axes except the last are multiplied together
            to count pixels, so the last axis is treated as channels.
        lmbda: Weight of the distortion term in the loss.
        mode: ``'training'`` runs both bottlenecks in training mode.
            ``"testing"`` additionally builds the compressed bitstrings.
        layers: Optional tuple ``(analysis, hyper_analysis, entropy_bottleneck,
            hyper_synthesis, synthesis)`` to reuse layers from an earlier call.
            When None, fresh layers with 192 filters are created.
        msssim_loss: If True the distortion term is ``msssim``, otherwise
            ``mse``.

    Returns:
        ``(loss, bpp, mse, msssim, x_tilde_hat, y_tilde_hat, z_tilde_hat, y,
        z, string, side_string, layers)``. ``string``/``side_string`` are None
        unless ``mode == "testing"``. ``msssim`` holds the mean of
        ``1 - MS-SSIM``, not MS-SSIM itself.
    """
    is_training = mode == 'training'
    num_pixels = tf.to_float(tf.reduce_prod(tf.shape(x)[:-1]))

    if layers is not None:
        (analysis_transform, hyper_analysis_transform, entropy_bottleneck,
         hyper_synthesis_transform, synthesis_transform) = layers
    else:
        num_filters = 192
        analysis_transform = AnalysisTransform(num_filters)
        synthesis_transform = SynthesisTransform(num_filters)
        hyper_analysis_transform = HyperAnalysisTransform(num_filters)
        hyper_synthesis_transform = HyperSynthesisTransform(num_filters)
        entropy_bottleneck = tfc.EntropyBottleneck()
        layers = (analysis_transform, hyper_analysis_transform,
                  entropy_bottleneck, hyper_synthesis_transform,
                  synthesis_transform)

    # Autoencoder with a mean/scale hyperprior.
    y = analysis_transform(x)
    z = hyper_analysis_transform(y)
    z_tilde_hat, z_likelihoods = entropy_bottleneck(z, training=is_training)
    mean, sigma = hyper_synthesis_transform(z_tilde_hat)
    scale_table = np.exp(np.linspace(
        np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mean)
    y_tilde_hat, y_likelihoods = conditional_bottleneck(y, training=is_training)
    x_tilde_hat = synthesis_transform(y_tilde_hat)

    string = side_string = None
    if mode == "testing":
        side_string = entropy_bottleneck.compress(z_tilde_hat)
        string = conditional_bottleneck.compress(y_tilde_hat)

    # Rate: negative log2-likelihood of y and z, per pixel.
    log_likelihood = (tf.reduce_sum(tf.log(y_likelihoods))
                      + tf.reduce_sum(tf.log(z_likelihoods)))
    bpp = log_likelihood / (-np.log(2) * num_pixels)

    # Distortion: MSE scaled by 255**2, and mean (1 - MS-SSIM) with max_val 1.
    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde_hat)) * 255 ** 2
    msssim = tf.reduce_mean(1 - tf.image.ssim_multiscale(x_tilde_hat, x, 1))

    loss = lmbda * (msssim if msssim_loss else mse) + bpp
    return (loss, bpp, mse, msssim, x_tilde_hat, y_tilde_hat, z_tilde_hat,
            y, z, string, side_string, layers)
def decompress(args):
    """Decompresses an image.

    Reads the packed (string, side_string, x_shape, y_shape, z_shape) from
    ``args.input_file`` and decodes y with a Gaussian conditional whose mean
    and sigma come from the hyper synthesis transform. The reconstruction is
    written as a PNG to ``args.output_file``. Weights are restored from the
    newest checkpoint in ``args.checkpoint_dir/args.runname``.
    """
    # Adapted from https://github.com/tensorflow/compression/blob/master/examples/bmshj2018.py
    # Placeholders for the packed fields, in packing order.
    string = tf.placeholder(tf.string, [1])
    side_string = tf.placeholder(tf.string, [1])
    x_shape = tf.placeholder(tf.int32, [2])
    y_shape = tf.placeholder(tf.int32, [2])
    z_shape = tf.placeholder(tf.int32, [2])
    tensors = [string, side_string, x_shape, y_shape, z_shape]
    with open(args.input_file, "rb") as f:
        arrays = tfc.PackedTensors(f.read()).unpack(tensors)

    # Instantiate model. TODO: automate this with build_graph
    filters = args.num_filters
    synthesis_transform = SynthesisTransform(filters)
    hyper_synthesis_transform = HyperSynthesisTransform(
        filters, num_output_filters=2 * filters)
    entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32)

    # Decode z, then split the hyper synthesis output into mean and log-scale
    # halves along the channel axis.
    z_hat = entropy_bottleneck.decompress(
        side_string, tf.concat([z_shape, [filters]], axis=0), channels=filters)
    mu, log_sigma = tf.split(
        hyper_synthesis_transform(z_hat), num_or_size_splits=2, axis=-1)
    sigma = tf.exp(log_sigma)  # make positive
    # mu/sigma must match y's stored spatial shape (images of non-standard size).
    mu = mu[:, :y_shape[0], :y_shape[1], :]
    sigma = sigma[:, :y_shape[0], :y_shape[1], :]

    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(
        sigma, scale_table, mean=mu, dtype=tf.float32)
    y_hat = conditional_bottleneck.decompress(string)

    # Drop the batch dimension and crop bottom/right padding, then encode as PNG.
    x_hat = synthesis_transform(y_hat)[0, :x_shape[0], :x_shape[1], :]
    op = write_png(args.output_file, x_hat)

    # Restore the newest checkpoint and run the decode + write.
    with tf.Session() as sess:
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        checkpoint = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=checkpoint)
        sess.run(op, feed_dict=dict(zip(tensors, arrays)))
def load_test_model_graph(checkpoint_dir):
    """Build the evaluation graph (entropy_bottleneck(training=False)) and restore weights.

    Args:
        checkpoint_dir: Directory searched with ``tf.train.latest_checkpoint``.

    Returns:
        ``(sess, x, orig_x, [string, side_string], eval_bpp, x_hat, mse, psnr,
        msssim, num_pixels, y, z)``.
        - ``x`` is the placeholder fed to the model.
        - ``orig_x`` is the reference placeholder used for the metrics.
        - ``x_hat`` is cropped to ``x``'s spatial size, clipped to [0, 1],
          scaled by 255 and rounded.
        - ``sess`` is a new session owned by the caller.
    """
    # Inputs. Both are multiplied by 255 for the metrics below, so they are
    # presumably float images in [0, 1] (TODO confirm with callers).
    x = tf.placeholder(tf.float32, [1, None, None, 3])
    orig_x = tf.placeholder(tf.float32, [1, None, None, 3])

    # Instantiate model.
    analysis_transform = AnalysisTransform(192)
    synthesis_transform = SynthesisTransform(192)
    hyper_analysis_transform = HyperAnalysisTransform(192)
    hyper_synthesis_transform = HyperSynthesisTransform(192)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Transform and compress the image.
    y = analysis_transform(x)
    y_shape = tf.shape(y)
    z = hyper_analysis_transform(abs(y))
    z_hat, z_likelihoods = entropy_bottleneck(z, training=False)
    sigma = hyper_synthesis_transform(z_hat)
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table)
    side_string = entropy_bottleneck.compress(z)
    string = conditional_bottleneck.compress(y)

    # Transform the quantized latents back to an image.
    y_hat, y_likelihoods = conditional_bottleneck(y, training=False)
    x_hat = synthesis_transform(y_hat)
    # Crop bottom/right padding so x_hat has the input's spatial size.
    # Without this, the element-wise metrics below get mismatched shapes for
    # image sizes the transforms do not divide evenly.
    x_shape = tf.shape(x)
    x_hat = x_hat[:, :x_shape[1], :x_shape[2], :]

    # Estimated bits per pixel from both likelihoods.
    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), dtype=tf.float32)
    eval_bpp = (tf.reduce_sum(tf.log(y_likelihoods)) + tf.reduce_sum(
        tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)

    # Reconstruction metrics: bring both images back to the 0..255 range.
    orig_x_255 = orig_x * 255
    x_hat = tf.clip_by_value(x_hat, 0, 1)
    x_hat = tf.round(x_hat * 255)
    mse = tf.reduce_mean(tf.squared_difference(orig_x_255, x_hat))
    psnr = tf.squeeze(tf.image.psnr(x_hat, orig_x_255, 255))
    msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, orig_x_255, 255))

    # Session with the newest checkpoint restored.
    sess = tf.Session()
    latest = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir)
    tf.train.Saver().restore(sess, save_path=latest)
    return sess, x, orig_x, [
        string, side_string
    ], eval_bpp, x_hat, mse, psnr, msssim, num_pixels, y, z
def train(self, x, gamma, alpha, lmbda):
    """Initializes the training model

    Builds the hyperprior autoencoder for ``x`` and sets these attributes:
    - ``self.train_mbpov``: rate term, normalised by ``tf.reduce_sum(x)``.
    - ``self.train_fl``: focal-loss distortion.
    - ``self.train_loss``: ``lmbda * train_fl + train_mbpov``.
    - ``self.merged_summary``: merged summaries.
    - ``self.step``: the global step.
    - ``self.train_op``: groups the main Adam step, the auxiliary
      entropy-bottleneck step and the bottleneck update op.

    Args:
        x: Input tensor. Its sum is used as the number of occupied voxels, so
            it presumably holds binary occupancy values (TODO confirm).
        gamma, alpha: Forwarded to ``focal_loss``.
        lmbda: Weight of the focal loss in the total loss.
    """
    fmt = self.data_format
    width = self.num_filters
    # Same instantiation order as before, so variable names are unchanged.
    analysis_transform = self.analysis_transform_class(width, data_format=fmt)
    synthesis_transform = self.synthesis_transform_class(width, data_format=fmt)
    hyper_analysis_transform = self.hyper_analysis_transform_class(width, data_format=fmt)
    hyper_synthesis_transform = self.hyper_synthesis_transform_class(width, data_format=fmt)
    entropy_bottleneck = tfc.EntropyBottleneck(data_format=fmt)

    # Build autoencoder; both bottlenecks in training mode.
    y = analysis_transform(x)
    z = hyper_analysis_transform(y)
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=True)
    sigma_tilde = hyper_synthesis_transform(z_tilde)
    conditional_bottleneck = tfc.GaussianConditional(sigma_tilde, self.scale_table)
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=True)
    x_tilde = synthesis_transform(y_tilde)

    x_quant = quantize_tensor(x)
    x_tilde_quant = quantize_tensor(x_tilde)

    # Loss: rate in bits (log base 2) divided by the sum of x.
    num_occupied_voxels = tf.reduce_sum(x)
    log_y_likelihoods = tf.log(y_likelihoods)
    log_z_likelihoods = tf.log(z_likelihoods)
    bits_divisor = -np.log(2) * num_occupied_voxels
    train_mbpov_y = tf.reduce_sum(log_y_likelihoods) / bits_divisor
    train_mbpov_z = tf.reduce_sum(log_z_likelihoods) / bits_divisor
    self.train_mbpov = train_mbpov_y + train_mbpov_z
    self.train_fl = focal_loss(x, x_tilde, gamma=gamma, alpha=alpha)
    self.train_loss = lmbda * self.train_fl + self.train_mbpov

    v1_summaries(self.train_loss, train_mbpov_y, self.train_mbpov, self.train_fl,
                 log_y_likelihoods, num_occupied_voxels, x, x_tilde,
                 x_tilde_quant, y, y_likelihoods, y_tilde)
    v2_summaries(log_z_likelihoods, sigma_tilde, train_mbpov_z, z,
                 z_likelihoods, z_tilde)
    binary_classification_summaries(x_quant, x_tilde_quant)
    self.merged_summary = tf.summary.merge_all()

    # Main loss: Adam at 1e-4. Auxiliary bottleneck loss: Adam at 1e-3.
    # The main optimizer is created first to keep its slot-variable names.
    self.step = tf.train.get_or_create_global_step()
    main_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(
        self.train_loss, global_step=self.step)
    aux_step = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(
        entropy_bottleneck.losses[0])
    self.train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.
    During training we sample from box-shaped posteriors; during compression this is approximated by rounding.

    Args:
        args: Namespace providing ``num_filters``.
        x: Input image tensor, as described above.
        training: If True, both bottlenecks are called with training=True.
            If False:
            - mu/sigma are cropped to y's spatial shape;
            - ``side_string`` and ``string`` are built;
            - ``x_tilde`` is cropped to x's spatial shape.

    Returns:
        ``locals()``: a dict of every local name bound in this function, e.g.
        ``x_tilde``, ``y_likelihoods``, ``z_likelihoods``, ``y``, ``z``, and,
        when not training, ``string`` and ``side_string``. Callers read results
        by these names, so renaming or adding a local changes the returned keys.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    # Hyper synthesis outputs 2 * num_filters channels: split below into mu and log-sigma.
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, num_output_filters=2 * args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)  # y = g_a(x)
    z = hyper_analysis_transform(y)  # z = h_a(y)

    # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior
    # p(z_tilde) ("z_likelihoods")
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde), num_or_size_splits=2, axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mu)
    # sample y_tilde from q(y_tilde|x) = U(y-0.5, y+0.5) = U(g_a(x)-0.5, g_a(x)+0.5), and then compute the pdf of
    # y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde) = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=training)
    x_tilde = synthesis_transform(y_tilde)

    if not training:
        # Compression mode: produce the bitstrings for y and the side information z.
        side_string = entropy_bottleneck.compress(z)
        string = conditional_bottleneck.compress(y)
        x_shape = tf.shape(x)
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to have the same shape as input

    # NOTE: returns every local by name; keep local names stable.
    return locals()
def decompress(args):
    """Decompresses an image.

    Unpacks (string, side_string, x_shape, y_shape, z_shape) from
    ``args.input_file`` and decodes y with a Gaussian conditional model
    parameterised by sigma only (no mean argument). The reconstruction is
    written as a PNG to ``args.output_file``. Weights come from the newest
    checkpoint in ``args.checkpoint_dir``.
    """
    # Placeholders for the packed fields, in packing order.
    string = tf.placeholder(tf.string, [1])
    side_string = tf.placeholder(tf.string, [1])
    x_shape = tf.placeholder(tf.int32, [2])
    y_shape = tf.placeholder(tf.int32, [2])
    z_shape = tf.placeholder(tf.int32, [2])
    tensors = [string, side_string, x_shape, y_shape, z_shape]
    with open(args.input_file, "rb") as f:
        arrays = tfc.PackedTensors(f.read()).unpack(tensors)

    filters = args.num_filters
    synthesis_transform = SynthesisTransform(filters)
    hyper_synthesis_transform = HyperSynthesisTransform(filters)
    entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32)

    # Decode z, then predict sigma cropped to y's stored spatial shape.
    z_hat = entropy_bottleneck.decompress(
        side_string, tf.concat([z_shape, [filters]], axis=0), channels=filters)
    sigma = hyper_synthesis_transform(z_hat)[:, :y_shape[0], :y_shape[1], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, dtype=tf.float32)
    y_hat = conditional_bottleneck.decompress(string)

    # Drop the batch dimension and crop bottom/right padding, then encode as PNG.
    x_hat = synthesis_transform(y_hat)[0, :x_shape[0], :x_shape[1], :]
    op = write_png(args.output_file, x_hat)

    # Restore the newest checkpoint and run the decode + write.
    with tf.Session() as sess:
        checkpoint = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=checkpoint)
        sess.run(op, feed_dict=dict(zip(tensors, arrays)))
def decompress(args):
    """Decompresses an image.

    Reads the packed (string, side_string, x_shape, y_shape, z_shape) from
    ``args.input_file``. It falls back to a legacy hard-coded path only when
    ``args`` has no ``input_file`` or it is empty. The hyper-latents and
    latents are decoded with ``HyperDecoder``/``decoder``, and the
    reconstruction is written as a PNG to ``args.output_file``.
    """
    # Bug fix: the input path used to be hard-coded, ignoring args.input_file.
    legacy_input = ('/media/xproject/file/Surige/compression-master/examples/'
                    'rnn_baseline/recon/recon.bin')
    input_file = getattr(args, "input_file", None) or legacy_input

    # Placeholders for the packed fields, in packing order.
    string = tf.placeholder(tf.string, [1])
    side_string = tf.placeholder(tf.string, [1])
    x_shape = tf.placeholder(tf.int32, [2])
    y_shape = tf.placeholder(tf.int32, [2])
    z_shape = tf.placeholder(tf.int32, [2])
    with open(input_file, "rb") as f:
        packed = tfc.PackedTensors(f.read())
    tensors = [string, side_string, x_shape, y_shape, z_shape]
    arrays = packed.unpack(tensors)

    # Decoders are sized from the stored image shape. The hyper decoder works
    # at 1/16 of the image height/width.
    d = decoder(args.batchsize, height=x_shape[0], width=x_shape[1])
    hd = HyperDecoder(args.batchsize, height=x_shape[0] // 16, width=x_shape[1] // 16)
    entropy_bottleneck = tfc.EntropyBottleneck(name='entropy_iter', dtype=tf.float32)

    # Decompress and transform the image back.
    full_z_shape = tf.concat([z_shape, [args.num_filters]], axis=0)
    z_hat = entropy_bottleneck.decompress(
        side_string, full_z_shape, channels=args.num_filters)
    sigma = hd.hyper_decode(z_hat)
    sigma = sigma[:, :y_shape[0], :y_shape[1], :]
    scale_table = np.exp(np.linspace(
        np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(
        sigma, scale_table, dtype=tf.float32)
    y_hat = conditional_bottleneck.decompress(string)
    x_hat = d.decode(y_hat)

    # Remove batch dimension, and crop away any extraneous padding on the bottom
    # or right boundaries.
    x_hat = x_hat[0, :x_shape[0], :x_shape[1], :]

    # Write reconstructed image out as a PNG file.
    op = write_png(args.output_file, x_hat)

    # Load the latest model checkpoint, and perform the above actions.
    with tf.Session() as sess:
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        sess.run(op, feed_dict=dict(zip(tensors, arrays)))
def compress(self, x_shape):
    """Initializes the compression model

    Sets the following attributes:
    - ``self.x``: float32 placeholder of shape ``x_shape``.
    - ``self.strings``: ``(y_string, z_string)``, the encoded bitstrings.
    - ``self.x_hat``: the image decoded from those strings.
    - ``self.debug_tensors``: intermediate tensors. Key order is z_hat,
      sigma_hat, the entries of ``conditional_bottleneck.dbg_dec``, y_hat,
      x_hat.
    """
    fmt = self.data_format
    width = self.num_filters
    analysis_transform = self.analysis_transform_class(width, data_format=fmt)
    synthesis_transform = self.synthesis_transform_class(width, data_format=fmt)
    hyper_analysis_transform = self.hyper_analysis_transform_class(width, data_format=fmt)
    hyper_synthesis_transform = self.hyper_synthesis_transform_class(width, data_format=fmt)
    entropy_bottleneck = tfc.EntropyBottleneck(data_format=fmt)

    self.x = tf.placeholder(tf.float32, shape=x_shape)
    y = analysis_transform(self.x)
    z = hyper_analysis_transform(y)

    # Code z, then decode it, so sigma is predicted from the decoded z_hat
    # rather than from z.
    z_string = entropy_bottleneck.compress(z)
    z_hat = entropy_bottleneck.decompress(z_string, tf.shape(z)[1:], channels=width)
    sigma_hat = hyper_synthesis_transform(z_hat)

    conditional_bottleneck = tfc.GaussianConditional(
        sigma_hat, self.scale_table, dtype=tf.float32)
    y_string = conditional_bottleneck.compress(y)
    y_hat = conditional_bottleneck.decompress(y_string)
    self.x_hat = synthesis_transform(y_hat)
    self.strings = (y_string, z_string)

    debug = {'z_hat': z_hat, 'sigma_hat': sigma_hat}
    debug.update(conditional_bottleneck.dbg_dec)
    debug.update({'y_hat': y_hat, 'x_hat': self.x_hat})
    self.debug_tensors = debug
def decompress(self):
    """Initializes the decompression model

    Sets the following attributes:
    - ``self.strings_t``: string placeholders ``[y_string, z_string]``.
    - ``self.x_shape_t``: int32 placeholder with 3 elements.
    - ``self.x_hat``: the image decoded from the placeholders.
    - ``self.debug_tensors``: intermediate tensors. Key order is z_hat,
      sigma_hat, the entries of ``conditional_bottleneck.dbg_dec``, y_hat,
      x_hat.
    """
    fmt = self.data_format
    width = self.num_filters
    synthesis_transform = self.synthesis_transform_class(width, data_format=fmt)
    hyper_synthesis_transform = self.hyper_synthesis_transform_class(width, data_format=fmt)
    entropy_bottleneck = tfc.EntropyBottleneck(data_format=fmt)

    y_string_t = tf.placeholder(tf.string)
    z_string_t = tf.placeholder(tf.string)
    self.x_shape_t = tf.placeholder(tf.int32, shape=(3, ))
    self.strings_t = [y_string_t, z_string_t]

    # Hyper-latent shape: x_shape_t // 16, with a channel axis added by
    # add_channels (presumably placed according to data_format).
    z_shape_t = add_channels(self.x_shape_t // 16, width, fmt)
    z_hat = entropy_bottleneck.decompress(z_string_t, z_shape_t, channels=width)
    sigma_hat = hyper_synthesis_transform(z_hat)

    conditional_bottleneck = tfc.GaussianConditional(
        sigma_hat, self.scale_table, dtype=tf.float32)
    y_hat = conditional_bottleneck.decompress(y_string_t)
    self.x_hat = synthesis_transform(y_hat)

    debug = {'z_hat': z_hat, 'sigma_hat': sigma_hat}
    debug.update(conditional_bottleneck.dbg_dec)
    debug.update({'y_hat': y_hat, 'x_hat': self.x_hat})
    self.debug_tensors = debug
def test_compress(args):
    """Sweep synthesis-transform channel widths and rate levels over kodak/*.png.

    For every width triple in ``ac_list`` and every rate level 0..7:
    - each image matching ``kodak/*.png`` is run through the model;
    - the mean estimated bpp and the mean MSE (on 0..255 rounded images) are
      written as one row of ``f5.csv``.

    Averages use the number of images actually found (previously a fixed 24).

    Args:
        args: Namespace providing ``num_filters`` and ``checkpoint_dir``.

    Raises:
        FileNotFoundError: If no image matches ``kodak/*.png``. Previously
            this silently wrote all-zero rows.
    """
    # Load input image and add batch dimension.
    fn = tf.placeholder(tf.string, [])
    x = read_png(fn)
    x = tf.expand_dims(x, 0)
    x.set_shape([1, None, None, 3])
    x_shape = tf.shape(x)

    lmbda_level = tf.placeholder(tf.int32, [])
    lmbda_onehot = tf.one_hot(tf.reshape(lmbda_level, [1]), depth=8)
    # Constant widths, used only for the first synthesis call that builds the
    # variables before the Saver is created.
    actives = [tf.constant(256), tf.constant(256), tf.constant(256), tf.constant(3)]

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters, lmbda_onehot)
    synthesis_transform = SynthesisTransform(args.num_filters, lmbda_onehot, actives)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters, lmbda_onehot)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, lmbda_onehot)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Transform and compress the image.
    y = analysis_transform(x)
    y_shape = tf.shape(y)
    z = hyper_analysis_transform(abs(y))
    z_hat, z_likelihoods = entropy_bottleneck(z, training=False)
    sigma = hyper_synthesis_transform(z_hat)
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(np.linspace(
        np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table)
    side_string = entropy_bottleneck.compress(z)
    string = conditional_bottleneck.compress(y)

    y_hat, y_likelihoods = conditional_bottleneck(y, training=False)
    # Build the synthesis variables so the Saver below can restore them.
    synthesis_transform(y_hat)

    # Total number of bits divided by number of pixels (independent of the
    # synthesis widths, so it is built once here).
    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), dtype=tf.float32)
    eval_bpp = (tf.reduce_sum(tf.log(y_likelihoods))
                + tf.reduce_sum(tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)

    # Sorted for a deterministic accumulation order; fail loudly if empty.
    filenames = sorted(glob.glob("kodak/*.png"))
    if not filenames:
        raise FileNotFoundError("no images matched kodak/*.png")
    num_images = len(filenames)

    with tf.Session() as sess:
        # Load the latest model checkpoint.
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)

        # Re-wire the synthesis transform so the first three widths are fed
        # at run time.
        actives = [tf.placeholder(tf.int32, []), tf.placeholder(tf.int32, []),
                   tf.placeholder(tf.int32, []), tf.constant(3)]
        synthesis_transform.actives = actives
        x_hat = synthesis_transform(y_hat)
        x_hat = x_hat[:, :x_shape[1], :x_shape[2], :]

        # Bring both images back to 0..255 range.
        x_im = x * 255
        x_hat_im = tf.round(tf.clip_by_value(x_hat, 0, 1) * 255)
        mse = tf.reduce_mean(tf.squared_difference(x_im, x_hat_im))

        ac_list = [(i, j, k)
                   for i in [32, 88, 128, 192, 256]
                   for j in [32, 72, 128, 168, 256]
                   for k in [32, 64, 96, 144, 256]]

        with open("f5.csv", "w") as f:
            print("hash,level,bpp,mse", file=f)
            for acs in ac_list:
                if acs[1] == 32 and acs[2] == 32:
                    print(acs[0])  # progress: once per first-width value
                width_feed = {actives[0]: acs[0], actives[1]: acs[1], actives[2]: acs[2]}
                for level in range(8):
                    count_bpp = 0
                    count_mse = 0
                    for filename in filenames:
                        v_eval_bpp, v_mse = sess.run(
                            [eval_bpp, mse],
                            feed_dict={fn: filename, lmbda_level: level, **width_feed})
                        count_bpp += v_eval_bpp
                        count_mse += v_mse
                    print("%03d%03d%03d, %d, %.4f, %.4f" % (
                        acs[0], acs[1], acs[2], level,
                        count_bpp / num_images, count_mse / num_images), file=f)
def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo:
    """Apply this layer to code `latents`.

    The hyper-latents are coded with ``self._side_entropy_model``. The latents
    are then modelled by a Gaussian conditional whose scale and mean come from
    ``self._synthesis_scale`` and ``self._synthesis_mean``.

    How ``latents`` are decoded depends on ``mode``:
    - training: ``_quantize`` around the predicted means;
    - validation: the quantized values from ``estimate_entropy``;
    - otherwise: a real compress/decompress round trip.

    Args:
      latents: Tensor of latent values to code.
      image_shape: The [height, width] of a reference frame.
      mode: The training, evaluation or validation mode of the model.

    Returns:
      A HyperInfo tuple. ``bitstring`` is None unless mode is neither
      training nor validation.
    """
    training = mode == ModelMode.TRAINING
    validation = mode == ModelMode.VALIDATION
    latent_shape = tf.shape(latents)[1:-1]

    # Side information: analyse, then entropy-code the hyper-latents.
    hyper_latents = self._analysis(latents, training=training)
    side_info = self._side_entropy_model(
        hyper_latents, image_shape=image_shape, mode=mode, training=training)
    hyper_decoded = side_info.decoded

    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    latent_scales = self._synthesis_scale(hyper_decoded, training=training)
    latent_means = self._synthesis_mean(
        tf.cast(hyper_decoded, tf.float32), training=training)

    if not training and not validation:
        # Evaluation: crop the predicted parameters to the latent grid.
        height, width = latent_shape[0], latent_shape[1]
        latent_scales = latent_scales[:, :height, :width, :]
        latent_means = latent_means[:, :height, :width, :]

    conditional_entropy_model = tfc.GaussianConditional(
        latent_scales, scale_table, mean=latent_means,
        name="conditional_entropy_model")
    entropy_info = estimate_entropy(
        conditional_entropy_model, latents, spatial_shape=image_shape)

    compressed = None
    if training:
        latents_decoded = _quantize(latents, latent_means)
    elif validation:
        latents_decoded = entropy_info.quantized
    else:
        compressed = conditional_entropy_model.compress(latents)
        latents_decoded = conditional_entropy_model.decompress(compressed)

    info = HyperInfo(
        decoded=latents_decoded,
        latent_shape=latent_shape,
        hyper_latent_shape=side_info.latent_shape,
        nbpp=entropy_info.nbpp,
        side_nbpp=side_info.total_nbpp,
        total_nbpp=entropy_info.nbpp + side_info.total_nbpp,
        qbpp=entropy_info.qbpp,
        side_qbpp=side_info.total_qbpp,
        total_qbpp=entropy_info.qbpp + side_info.total_qbpp,
        bitstring=compressed,
        side_bitstring=side_info.bitstring)

    tf.summary.scalar("bpp/total/noisy", info.total_nbpp)
    tf.summary.scalar("bpp/total/quantized", info.total_qbpp)
    return info
def compress(args): """Compresses an image, or a batch of images of the same shape in npy format.""" from configs import get_eval_batch_size if args.input_file.endswith('.npy'): # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3] X = np.load(args.input_file) else: # Load input image and add batch dimension. from PIL import Image x = np.asarray(Image.open(args.input_file).convert('RGB')) X = x[None, ...] num_images = int(X.shape[0]) img_num_pixels = int(np.prod(X.shape[1:-1])) X = X.astype('float32') X /= 255. eval_batch_size = get_eval_batch_size(img_num_pixels) dataset = tf.data.Dataset.from_tensor_slices(X) dataset = dataset.batch(batch_size=eval_batch_size) # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like # sess.run([op1, op2, ...]). # x = dataset.make_one_shot_iterator().get_next() x_next = dataset.make_one_shot_iterator().get_next() x_ph = x = tf.placeholder( 'float32', (None, *X.shape[1:])) # keep a reference around for feed_dict #### BEGIN build compression graph #### # Instantiate model. 
analysis_transform = AnalysisTransform(args.num_filters) synthesis_transform = SynthesisTransform(args.num_filters) hyper_analysis_transform = HyperAnalysisTransform(args.num_filters) hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, num_output_filters=2 * args.num_filters) entropy_bottleneck = tfc.EntropyBottleneck() # Initial values for optimization y_init = analysis_transform(x) z_init = hyper_analysis_transform(y_init) y = tf.placeholder('float32', y_init.shape) T = tf.placeholder('float32', shape=[], name='temperature') y_floor = tf.floor(y) y_ceil = tf.ceil(y) y_bds = tf.stack([y_floor, y_ceil], axis=-1) epsilon = 1e-5 ry_logits = tf.stack( [ -tf.math.atanh( tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T, -tf.math.atanh( tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T ], axis=-1 ) # last dim are logits for DOWN or UP; clip to prevent NaN as temperature -> 0 ry = tf.nn.softmax(ry_logits, axis=-1) y_tilde = tf.reduce_sum(y_bds * ry, axis=-1) # inner product in last dim x_tilde = synthesis_transform(y_tilde) x_shape = tf.shape(x) x_tilde = x_tilde[:, :x_shape[1], :x_shape[ 2], :] # crop reconstruction to have the same shape as input # # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior # # p(z_tilde) ("z_likelihoods") # z_tilde, z_likelihoods = entropy_bottleneck(z, training=training) z = tf.placeholder('float32', z_init.shape) z_floor = tf.floor(z) z_ceil = tf.ceil(z) z_bds = tf.stack([z_floor, z_ceil], axis=-1) rz_logits = tf.stack([ -tf.math.atanh(tf.clip_by_value(z - z_floor, -1 + epsilon, 1 - epsilon)) / T, -tf.math.atanh(tf.clip_by_value(z_ceil - z, -1 + epsilon, 1 - epsilon)) / T ], axis=-1) # last dim are logits for DOWN or UP rz = tf.nn.softmax(rz_logits, axis=-1) z_tilde = tf.reduce_sum(z_bds * rz, axis=-1) # inner product in last dim # # We have to manually call entropy_bottleneck.build because we don't directly call entropy_bottleneck like 
we did # # with 'z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)' during training # # UPDATE: this doesn't quite work, as the resulting variables don't have the proper name scope (will just be named # # "matrix_0", "bias_0", etc., instead of "entropy_bottleneck/matrix_0", "entropy_bottleneck/bias_0" as would with # # calling entropy_bottleneck on tensor, which breaks model loading (will get "Key bias_0 not found in checkpoint.. # # tensorflow.python.framework.errors_impl.NotFoundError: Restoring from checkpoint failed. This is most likely due # # to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not # # altered the graph expected based on the checkpoint."). # entropy_bottleneck.build(z_tilde.shape) _ = entropy_bottleneck( z, training=False ) # dummy call to ensure entropy_bottleneck is properly built z_likelihoods = entropy_bottleneck._likelihood(z_tilde) # p(\tilde z) if entropy_bottleneck.likelihood_bound > 0: likelihood_bound = entropy_bottleneck.likelihood_bound z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_bound) # compute parameters of p(y_tilde|z_tilde) mu, sigma = tf.split(hyper_synthesis_transform(z_tilde), num_or_size_splits=2, axis=-1) sigma = tf.exp(sigma) # make positive # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y y_shape = tf.shape(y_tilde) mu = mu[:, :y_shape[1], :y_shape[2], :] sigma = sigma[:, :y_shape[1], :y_shape[2], :] scale_table = np.exp( np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS)) conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mu) # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde) # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5) y_likelihoods = conditional_bottleneck._likelihood( y_tilde) # p(\tilde y | \tilde z) if conditional_bottleneck.likelihood_bound > 0: likelihood_bound = 
conditional_bottleneck.likelihood_bound y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound) #### END build compression graph #### # Total number of bits divided by number of pixels. # - log p(\tilde y | \tilde z) - log p(\tilde z) - - log q(\tilde z | \tilde y) axes_except_batch = list(range(1, len(x.shape))) # should be [1,2,3] y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / ( np.log(2) * img_num_pixels) z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / ( np.log(2) * img_num_pixels) eval_bpp = y_bpp + z_bpp # shape (N,) train_bpp = tf.reduce_mean(eval_bpp) # Mean squared error across pixels. train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde)) # Multiply by 255^2 to correct for rescaling. # float_train_mse = train_mse # psnr = - 10 * (tf.log(float_train_mse) / np.log(10)) # float MSE computed on float images train_mse *= 255**2 # The rate-distortion cost. if args.lmbda < 0: args.lmbda = float(args.runname.split('lmbda=')[1].split('-') [0]) # re-use the lmbda as used for training print( 'Defaulting lmbda (mse coefficient) to %g as used in model training.' % args.lmbda) if args.lmbda > 0: rd_loss = args.lmbda * train_mse + train_bpp else: rd_loss = train_bpp rd_gradients = tf.gradients(rd_loss, [y, z]) # Bring both images back to 0..255 range, for evaluation only. 
x *= 255 x_tilde = tf.clip_by_value(x_tilde, 0, 1) x_tilde = tf.round(x_tilde * 255) mse = tf.reduce_mean(tf.squared_difference(x, x_tilde), axis=axes_except_batch) # shape (N,) psnr = tf.image.psnr(x_tilde, x, 255) # shape (N,) msssim = tf.image.ssim_multiscale(x_tilde, x, 255) # shape (N,) msssim_db = -10 * tf.log(1 - msssim) / np.log(10) # shape (N,) with tf.Session() as sess: # Load the latest model checkpoint, get compression stats save_dir = os.path.join(args.checkpoint_dir, args.runname) latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir) tf.train.Saver().restore(sess, save_path=latest) eval_fields = [ 'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp', 'est_z_bpp' ] eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp] all_results_arrs = {key: [] for key in eval_fields } # append across all batches log_itv = 100 if save_opt_record: log_itv = 10 rd_lr = 0.005 rd_opt_its = 2000 annealing_rate = 4e-3 T_ub = 0.2 def annealed_temperature(t, r, ub, lb=1e-8, backend=np): # Using the exp schedule from section 4.2 of Jang et. al., ICLR2017 if backend is None: return min(max(np.exp(-r * t), lb), ub) else: return backend.minimum( backend.maximum(backend.exp(-r * t), lb), ub) from adam import Adam batch_idx = 0 while True: try: x_val = sess.run(x_next) x_feed_dict = {x_ph: x_val} # 1. 
Perform R-D optimization conditioned on ground truth x print('----RD Optimization----') y_cur, z_cur = sess.run([y_init, z_init], feed_dict=x_feed_dict) # np arrays adam_optimizer = Adam(lr=rd_lr) opt_record = { 'its': [], 'T': [], 'rd_loss': [], 'rd_loss_after_rounding': [] } for it in range(rd_opt_its): temperature = annealed_temperature(it, r=annealing_rate, ub=T_ub) grads, obj, mse_, train_bpp_, psnr_ = sess.run( [rd_gradients, rd_loss, train_mse, train_bpp, psnr], feed_dict={ y: y_cur, z: z_cur, **x_feed_dict, T: temperature }) y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads) if it % log_itv == 0 or it + 1 == rd_opt_its: psnr_ = psnr_.mean() if args.verbose: bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run( [train_bpp, psnr, rd_loss], feed_dict={ y_tilde: np.round(y_cur), z_tilde: np.round(z_cur), **x_feed_dict }) psnr_after_rounding = psnr_after_rounding.mean() print( 'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f' % (it, temperature, obj, mse_, train_bpp_, psnr_, rd_loss_after_rounding, bpp_after_rounding, psnr_after_rounding)) opt_record['rd_loss_after_rounding'].append( rd_loss_after_rounding) else: print( 'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f' % (it, temperature, obj, mse_, train_bpp_, psnr_)) opt_record['its'].append(it) opt_record['T'].append(temperature) opt_record['rd_loss'].append(obj) print() y_tilde_cur = np.round( y_cur) # this is the latents we end up transmitting z_tilde_cur = np.round(z_cur) # If requested, transform the quantized image back and measure performance. 
eval_arrs = sess.run(eval_tensors, feed_dict={ y_tilde: y_tilde_cur, z_tilde: z_tilde_cur, **x_feed_dict }) for field, arr in zip(eval_fields, eval_arrs): all_results_arrs[field] += arr.tolist() batch_idx += 1 except tf.errors.OutOfRangeError: break for field in eval_fields: all_results_arrs[field] = np.asarray(all_results_arrs[field]) input_file = os.path.basename(args.input_file) results_dict = all_results_arrs trained_script_name = args.runname.split('-')[0] script_name = os.path.splitext(os.path.basename(__file__))[ 0] # current script name, without extension # save RD evaluation results prefix = 'rd' save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file) if script_name != trained_script_name: save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % ( prefix, script_name, args.lmbda, args.runname, input_file) np.savez(os.path.join(args.results_dir, save_file), **results_dict) if save_opt_record: # save optimization record prefix = 'opt' save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file) if script_name != trained_script_name: save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % ( prefix, script_name, args.lmbda, args.runname, input_file) np.savez(os.path.join(args.results_dir, save_file), **opt_record) for field in eval_fields: arr = all_results_arrs[field] print('Avg {}: {:0.4f}'.format(field, arr.mean()))
def test_decompress(args, reference_path="kodak/kodim01.png"):
    """Decompress an image and report quality versus the number of latent channels kept.

    Reads the packed tensors ``[string, side_string, x_shape, y_shape,
    z_shape]`` from ``args.input_file``, decodes the hyper-latents and the
    latents, then, for ``active`` = ``args.num_filters``,
    ``args.num_filters - 8``, ..., reconstructs the image from only the first
    ``active`` latent channels. It prints ``active, mse, psnr, ms-ssim`` of
    each reconstruction against the PNG at ``reference_path``.

    Args:
      args: parsed command-line namespace; reads ``input_file``,
        ``num_filters`` and ``checkpoint_dir``.
      reference_path: PNG the reconstructions are compared against. Defaults
        to the image this experiment was originally hard-coded to.
    """
    # Placeholders, in the order the tensors were packed.
    string = tf.placeholder(tf.string, [1])
    side_string = tf.placeholder(tf.string, [1])
    x_shape = tf.placeholder(tf.int32, [2])
    y_shape = tf.placeholder(tf.int32, [2])
    z_shape = tf.placeholder(tf.int32, [2])
    with open(args.input_file, "rb") as f:
        packed = tfc.PackedTensors(f.read())
    tensors = [string, side_string, x_shape, y_shape, z_shape]
    arrays = packed.unpack(tensors)
    feed_dict = dict(zip(tensors, arrays))

    # Instantiate model.
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32)

    # Decode the hyper-latents, then the latents conditioned on them.
    z_full_shape = tf.concat([z_shape, [args.num_filters]], axis=0)
    z_hat = entropy_bottleneck.decompress(
        side_string, z_full_shape, channels=args.num_filters)
    sigma = hyper_synthesis_transform(z_hat)
    sigma = sigma[:, :y_shape[0], :y_shape[1], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table,
                                                     dtype=tf.float32)
    y_hat_all = conditional_bottleneck.decompress(string)

    # Reference image, scaled by 255 to match the rounded reconstructions.
    x = read_png(reference_path)
    x = tf.expand_dims(x, 0)
    x.set_shape([1, None, None, 3])
    x *= 255

    def _metric_ops(active):
        """Build (mse, psnr, ms-ssim) ops for a reconstruction from the first `active` channels."""
        x_hat = synthesis_transform(y_hat_all[:, :, :, :active])
        x_hat = tf.round(tf.clip_by_value(x_hat, 0, 1) * 255)
        mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
        psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
        msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))
        return [mse, psnr, msssim]

    # Build every evaluation op before the Saver is created. The first call to
    # synthesis_transform creates its variables, and they must already exist
    # when the Saver collects the variables to restore. Building up front also
    # keeps the graph from growing while the session runs.
    actives = list(range(args.num_filters, 0, -8))
    metric_ops = [_metric_ops(active) for active in actives]

    with tf.Session() as sess:
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        for active, ops in zip(actives, metric_ops):
            vmse, vpsnr, vmsssim = sess.run(ops, feed_dict=feed_dict)
            print(active, vmse, vpsnr, vmsssim)
def test_compress(args, lmbda_log_step=0.1, lmbda_log_max=7.0):
    """Compress ``args.input_file`` once per lambda level of a variable-rate model.

    The lambda levels ``lmbda_log`` run from 0 (inclusive) to
    ``lmbda_log_max`` (exclusive) in steps of ``lmbda_log_step``. They are
    passed to every transform through a one-shot dataset iterator, so each
    ``sess.run`` below evaluates the next level.

    For every level, the packed bitstream is written to ``args.output_file``,
    overwriting the previous level's file. The function then prints
    ``actual_bpp, estimated_bpp, mse, psnr, ms-ssim``.

    Args:
      args: parsed command-line namespace; reads ``input_file``,
        ``output_file``, ``num_filters`` and ``checkpoint_dir``.
      lmbda_log_step: spacing between consecutive lambda levels.
      lmbda_log_max: exclusive upper bound of the lambda levels.
    """
    # Load input image and add batch dimension.
    x = read_png(args.input_file)
    x = tf.expand_dims(x, 0)
    x.set_shape([1, None, None, 3])
    x_shape = tf.shape(x)

    # Single source of truth for the lambda levels. Both the dataset and the
    # evaluation loop below derive from this array, so they cannot drift apart.
    # (The rate-distortion weight of a level is 0.1 * 2 ** (lmbda_log - 6);
    # only the level itself is passed to the model.)
    lmbda_log_values = np.arange(0, lmbda_log_max, lmbda_log_step)
    lmbda_log = tf.data.Dataset.from_tensor_slices(
        tf.constant(lmbda_log_values, dtype=tf.float32)
    ).make_one_shot_iterator().get_next()

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters, lmbda_log)
    synthesis_transform = SynthesisTransform(args.num_filters, lmbda_log)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      lmbda_log)
    hyper_synthesis_transform = HyperSynthesisTransform(
        args.num_filters, lmbda_log)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Transform and compress the image.
    y = analysis_transform(x)
    y_shape = tf.shape(y)
    z = hyper_analysis_transform(abs(y))
    z_hat, z_likelihoods = entropy_bottleneck(z, training=False)
    sigma = hyper_synthesis_transform(z_hat)
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table)
    side_string = entropy_bottleneck.compress(z)
    string = conditional_bottleneck.compress(y)

    # Transform the quantized image back.
    y_hat, y_likelihoods = conditional_bottleneck(y, training=False)
    x_hat = synthesis_transform(y_hat)
    x_hat = x_hat[:, :x_shape[1], :x_shape[2], :]

    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), dtype=tf.float32)

    # Total number of bits divided by number of pixels.
    eval_bpp = (tf.reduce_sum(tf.log(y_likelihoods)) + tf.reduce_sum(
        tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)

    # Bring both images back to 0..255 range.
    x *= 255
    x_hat = tf.clip_by_value(x_hat, 0, 1)
    x_hat = tf.round(x_hat * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
    psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
    msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))

    with tf.Session() as sess:
        # Load the latest model checkpoint.
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        tensors = [
            string, side_string,
            tf.shape(x)[1:-1],
            tf.shape(y)[1:-1],
            tf.shape(z)[1:-1]
        ]
        # One sess.run per lambda level; each run pulls the next level from
        # the iterator.
        for _ in range(len(lmbda_log_values)):
            arrays, v_eval_bpp, v_mse, v_psnr, v_msssim, v_num_pixels = sess.run(
                [tensors, eval_bpp, mse, psnr, msssim, num_pixels])

            packed = tfc.PackedTensors()
            packed.pack(tensors, arrays)
            with open(args.output_file, "wb") as f:
                f.write(packed.string)

            # The actual bits per pixel including overhead.
            bpp = len(packed.string) * 8 / v_num_pixels
            print(bpp, v_eval_bpp, v_mse, v_psnr, v_msssim)
def test_train(args, warm_start_dir="./e5"):
    """Train the model, warm-starting from the latest checkpoint in ``warm_start_dir``.

    The loss is ``sum_c lambda_dist[c] * bpp[c] + mse``, where ``bpp[c]`` is
    the per-channel rate and ``lambda_dist`` is a fixed 256-entry weight
    vector.

    The warm-start checkpoint is applied through the ``Scaffold`` ``init_fn``
    of the ``MonitoredTrainingSession``. TensorFlow calls ``init_fn`` only when
    ``args.checkpoint_dir`` does not already hold a checkpoint; otherwise
    training resumes from that directory.

    Args:
      args: parsed command-line namespace; reads ``verbose``, ``train_glob``,
        ``preprocess_threads``, ``patchsize``, ``batchsize``, ``num_filters``,
        ``last_step`` and ``checkpoint_dir``.
      warm_start_dir: directory holding the checkpoint to initialize from.

    Raises:
      RuntimeError: if ``args.train_glob`` matches no files.
      ValueError: if ``warm_start_dir`` contains no checkpoint.
    """
    if args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # Create input data pipeline.
    with tf.device("/cpu:0"):
        train_files = glob.glob(args.train_glob)
        if not train_files:
            raise RuntimeError(
                "No training images found with glob '{}'.".format(
                    args.train_glob))
        train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
        train_dataset = train_dataset.shuffle(
            buffer_size=len(train_files)).repeat()
        train_dataset = train_dataset.map(
            read_png, num_parallel_calls=args.preprocess_threads)
        train_dataset = train_dataset.map(
            lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3)))
        train_dataset = train_dataset.batch(args.batchsize)
        train_dataset = train_dataset.prefetch(32)

    num_pixels = args.batchsize * args.patchsize**2

    # Get training patch from dataset.
    x = train_dataset.make_one_shot_iterator().get_next()

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters)
    entropy_bottleneck = DynamicEntropyBottleneck(name="entropy_bottleneck")

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)
    z = hyper_analysis_transform(abs(y))
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=True)
    sigma = hyper_synthesis_transform(z_tilde)
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table)
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=True)
    incep = InterceptNorate(32, 256, 1)
    y_incep = incep(y_tilde)
    x_tilde = synthesis_transform(y_incep)

    # Total bits per pixel (summary only) and MSE on the 0..255 scale.
    train_bpp = (tf.reduce_sum(tf.log(y_likelihoods)) + tf.reduce_sum(
        tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde)) * (255**2)

    # Per-channel rate weights: entries 0..31 all use the weight for w=32,
    # and entry i >= 32 uses the weight for w=i+1.
    def _rate_weight(w):
        return 10**(-w / 32.0 * 0.6 + 4.1)

    dist = np.zeros(256)
    dist[:32] = _rate_weight(32)
    for i in range(32, 256):
        dist[i] = _rate_weight(i + 1)
    lambda_dist = tf.constant(dist, dtype=tf.float32)

    # Per-channel bits per pixel: sums over axes 0-2 (presumably batch and
    # spatial axes of NHWC tensors -- TODO confirm), keeping the last axis.
    train_bpp_dist = (
        tf.reduce_sum(tf.log(y_likelihoods), axis=(0, 1, 2)) +
        tf.reduce_sum(tf.log(z_likelihoods), axis=(0, 1, 2))
    ) / (-np.log(2) * num_pixels)

    # The rate-distortion cost.
    train_loss = tf.reduce_sum(lambda_dist * train_bpp_dist) + train_mse

    # Warm start. Restoring into a separate tf.Session would be discarded when
    # that session closes, so the restore runs as the Scaffold init_fn of the
    # training session instead. The Saver is created here, before the global
    # step and the optimizer slot variables exist, so it only covers the model
    # variables.
    warm_start_path = tf.train.latest_checkpoint(checkpoint_dir=warm_start_dir)
    if warm_start_path is None:
        raise ValueError("No checkpoint found in warm-start directory '{}'.".format(
            warm_start_dir))
    warm_start_saver = tf.train.Saver()

    def _restore_warm_start(scaffold, session):
        del scaffold  # Unused; required by the Scaffold init_fn signature.
        warm_start_saver.restore(session, save_path=warm_start_path)

    # Minimize loss and auxiliary loss, and execute update op.
    step = tf.train.create_global_step()
    main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    main_step = main_optimizer.minimize(train_loss, global_step=step)
    aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0])
    train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])

    tf.summary.scalar("loss", train_loss)
    tf.summary.scalar("bpp", train_bpp)
    tf.summary.scalar("mse", train_mse)
    tf.summary.image("original", quantize_image(x))
    tf.summary.image("reconstruction", quantize_image(x_tilde))

    hooks = [
        tf.train.StopAtStepHook(last_step=args.last_step),
        tf.train.NanTensorHook(train_loss),
    ]
    scaffold = tf.train.Scaffold(init_fn=_restore_warm_start)
    with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                           hooks=hooks,
                                           checkpoint_dir=args.checkpoint_dir,
                                           save_checkpoint_secs=300,
                                           save_summaries_secs=60) as sess:
        while not sess.should_stop():
            sess.run(train_op)
def decompress(input_bin_path, input_res_path, output_img_path, ckp_dir, tau):
    """Decode a lossy-layer bitstream plus a range-coded residual into an image.

    The lossy layer (hyperprior codec) is decoded from ``input_bin_path``.
    The quantized residual is then decoded pixel by pixel and channel by
    channel from ``input_res_path``. Each channel uses a pmf predicted from
    the lossy layer's residual prior and the already-decoded residual
    context. The sum of the two is clipped to [0, 255] and saved.

    Args:
      input_bin_path: file holding the packed lossy-layer strings and shapes.
      input_res_path: range-coded residual file.
      output_img_path: path the reconstructed image is saved to (via PIL).
      ckp_dir: checkpoint directory to restore the model from.
      tau: residual quantization parameter; decoded residual symbols are
        mapped back with bins of width ``2 * tau + 1``.

    Returns:
      The reconstructed image as a float array clipped to [0, 255].
    """
    with tf.device('/cpu:0'):
        # Placeholders for the packed lossy-layer bitstream, in packing order.
        string = tf.placeholder(tf.string, [1])
        side_string = tf.placeholder(tf.string, [1])
        x_shape = tf.placeholder(tf.int32, [2])
        y_shape = tf.placeholder(tf.int32, [2])
        z_shape = tf.placeholder(tf.int32, [2])
        with open(input_bin_path, "rb") as f:
            packed = tfc.PackedTensors(f.read())
        tensors = [string, side_string, x_shape, y_shape, z_shape]
        arrays = packed.unpack(tensors)

        # instantiate model
        decoder = nll_codec.Decoder(192)
        hyper_decoder_sigma = nll_codec.HyperDecoder(192)
        hyper_decoder_mu = nll_codec.HyperDecoder(192)
        entropy_parameters_sigma = nll_codec.EntropyParameters(192)
        entropy_parameters_mu = nll_codec.EntropyParameters(192)
        entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32)
        res_compressor = nll_codec.ResidualCompressor(128, 5)
        masked_conv = nll_codec.MaskedConv2d("A", 64, (5, 5), padding="VALID")
        res_compressor_cond = bc.ResidualCompressor_cond(128, 5)

        # build decoder
        z_shape = tf.concat([z_shape, [192]], axis=0)
        z_hat_decode = entropy_bottleneck.decompress(
            side_string, z_shape, channels=192)  # decode z (including dequantization)
        psi_sigma = hyper_decoder_sigma(z_hat_decode)
        psi_mu = hyper_decoder_mu(z_hat_decode)
        sigma = entropy_parameters_sigma(psi_sigma)
        mu = entropy_parameters_mu(psi_mu)
        sigma = sigma[:, :y_shape[0], :y_shape[1], :]
        mu = mu[:, :y_shape[0], :y_shape[1], :]
        scale_table = np.exp(
            np.linspace(np.log(SCALE_MIN), np.log(SCALE_MAX), SCALES_LEVELS))
        conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                         scale_table,
                                                         mean=mu,
                                                         dtype=tf.float32)
        y_hat_decode = conditional_bottleneck.decompress(
            string)  # decode y (including dequantization)
        x_hat, res_prior = decoder(y_hat_decode)
        x_hat = x_hat[:, :x_shape[0], :x_shape[1], :]
        x_hat = tf.clip_by_value(x_hat, 0, 1)
        x_hat = tf.math.floor(x_hat * 255 + 0.5)
        res_prior = res_prior[:, :x_shape[0], :x_shape[1], :]

        tau_list = tf.constant([int(tau - 1)], tf.int32)
        cond = tf.one_hot(tau_list, 5)

        # x_shape is the (height, width) pair stored in the bitstream, so the
        # pixel count is the product of both entries. (Slicing off the last
        # entry, as for a full NHWC shape, would yield the height only.)
        num_pixels = tf.cast(tf.reduce_prod(x_shape), dtype=tf.float32)

        # Residual entropy model, evaluated one pixel at a time.
        res_q_patch = tf.placeholder(dtype=tf.float32, shape=(1, 5, 5, 3))
        res_prior_channel_num = 64
        res_prior_patch = tf.placeholder(dtype=tf.float32,
                                         shape=(1, 1, 1, res_prior_channel_num))
        res_q_vector = tf.placeholder(dtype=tf.float32, shape=(1, 1, 1, 3))

        bin_sz = 2 * tau + 1
        pmf_length = int(510 // bin_sz + 1)
        pmf_end = (255 // bin_sz) * bin_sz

        context = masked_conv(res_q_patch)
        res_prior_context = tf.concat([res_prior_patch, context], 3)

        bias_correction = True
        if bias_correction and int(tau) > 0:
            res_mu, res_log_sigma, res_pi, res_lambda = res_compressor_cond(
                res_prior_context, cond)
        else:
            res_mu, res_log_sigma, res_pi, res_lambda = res_compressor(
                res_prior_context)

        res_mu_tiled = tf.tile(res_mu,
                               tf.constant([pmf_length, 1, 1, 1], tf.int32))
        res_log_sigma_tiled = tf.tile(
            res_log_sigma, tf.constant([pmf_length, 1, 1, 1], tf.int32))
        res_pi_tiled = tf.tile(res_pi,
                               tf.constant([pmf_length, 1, 1, 1], tf.int32))
        res_lambda_tiled = tf.tile(
            res_lambda, tf.constant([pmf_length, 1, 1, 1], tf.int32))

        res_bottleneck = lmm.LogisticMixtureModel(res_mu_tiled,
                                                  res_log_sigma_tiled,
                                                  res_pi_tiled,
                                                  res_lambda_tiled)
        res_pmf = res_bottleneck.pmf_tau(res_q_vector, tau)

        with tf.Session() as sess:
            latest = tf.train.latest_checkpoint(checkpoint_dir=ckp_dir)
            tf.train.Saver().restore(sess, save_path=latest)

            # lossy image decoding
            print("Lossy Image Decoding Start.")
            res_prior_out, x_out, num_pixels_out, x_shape_out = sess.run(
                [res_prior, x_hat, num_pixels, x_shape],
                feed_dict=dict(zip(tensors, arrays)))
            print("Lossy Image Decoding Finish.")

            k_sz = 5
            pad_sz = 2
            x_h, x_w = x_shape_out
            x_c = 3
            res_q_dec_padded = np.zeros(
                (1, x_h + 2 * pad_sz, x_w + 2 * pad_sz, x_c))

            range_decoder = RangeDecoder(input_res_path)
            print('Residual Decoding Start.')
            try:
                for h_idx in range(x_h):
                    for w_idx in range(x_w):
                        # A view into res_q_dec_padded, so residuals decoded
                        # below are visible through it.
                        res_q_extracted = res_q_dec_padded[:, h_idx:h_idx + k_sz,
                                                           w_idx:w_idx + k_sz, :]
                        res_prior_extracted = res_prior_out[:, h_idx,
                                                            w_idx, :].reshape(
                                                                1, 1, 1,
                                                                res_prior_channel_num)
                        for c_idx in range(x_c):
                            # Re-read the current pixel for every channel: it
                            # holds the channels decoded so far and zeros for
                            # the rest.
                            res_q_vector_extracted = res_q_dec_padded[:, h_idx + pad_sz,
                                                                      w_idx + pad_sz, :].reshape(
                                                                          1, 1, 1, 3)
                            res_pmf_out = sess.run(res_pmf,
                                                   feed_dict={
                                                       res_q_patch: res_q_extracted,
                                                       res_prior_patch: res_prior_extracted,
                                                       res_q_vector: res_q_vector_extracted
                                                   })
                            # Must match the encoder's pmf -> cumulative
                            # frequency conversion exactly.
                            c_pmf = res_pmf_out[:, 0, 0, c_idx]
                            c_pmf = np.clip(c_pmf, 1.0 / 65025, 1.0)
                            c_pmf = c_pmf / np.sum(c_pmf)
                            cum_freq = np.floor(
                                np.append([0.], np.cumsum(c_pmf)) * 65536. +
                                0.5).astype(np.int32).tolist()
                            data_rec = range_decoder.decode(1, cum_freq)
                            # Map the symbol index back to a residual value.
                            res_q_dec_padded[0, h_idx + pad_sz, w_idx + pad_sz,
                                             c_idx] = data_rec[0] * bin_sz - pmf_end
                print("Decode Finish.")
            finally:
                range_decoder.close()

        res_q_dec = res_q_dec_padded[:, pad_sz:x_h + pad_sz,
                                     pad_sz:x_w + pad_sz, :]
        x_rec = np.clip(np.squeeze(x_out + res_q_dec), 0, 255)
        im = Image.fromarray(np.uint8(x_rec))
        im.save(output_img_path)
        return x_rec
def compress(input_path, output_bin_path, output_res_path, ckp_dir, tau):
    """Encode an image as a lossy-layer bitstream plus a range-coded residual.

    The input is reflect-padded to a multiple of 64 in height and width and
    coded with the hyperprior codec. The bitstream goes to
    ``output_bin_path``.

    The residual ``x - x_hat`` (on the unpadded image) is quantized to
    multiples of ``2 * tau + 1``. It is range-coded pixel by pixel and channel
    by channel into ``output_res_path``, using a pmf predicted from the
    residual prior and a masked 5x5 context. Size and quality statistics are
    printed.

    Args:
      input_path: PNG image to compress.
      output_bin_path: path for the packed lossy-layer bitstream.
      output_res_path: path for the range-coded residual.
      ckp_dir: checkpoint directory to restore the model from.
      tau: residual quantization parameter (bin width ``2 * tau + 1``).

    Returns:
      ``(total_bytes, lossy_layer_bytes, residual_bytes)``.
    """
    with tf.device('/cpu:0'):
        # Load and Pad Image
        x = read_png(input_path)

        # Reflect-pad height and width up to the next multiple of 64.
        mod = tf.constant([64, 64, 1], dtype=tf.int32)
        div = tf.cast(tf.math.ceil(tf.math.truediv(tf.shape(x), mod)), tf.int32)
        paddings = tf.math.subtract(tf.math.multiply(div, mod), tf.shape(x))
        paddings = tf.expand_dims(paddings, 1)
        paddings = tf.concat(
            [tf.convert_to_tensor(np.zeros((3, 1)), dtype=tf.int32), paddings],
            axis=1)

        x_pad = tf.pad(x, paddings, "REFLECT")
        x_pad = tf.expand_dims(x_pad, 0)
        x_pad.set_shape([1, None, None, 3])

        x = tf.expand_dims(x, 0)
        x.set_shape([1, None, None, 3])
        x_shape = tf.shape(x)
        x_norm = x_pad / 255

        # instantiate model
        encoder = nll_codec.Encoder(192)
        decoder = nll_codec.Decoder(192)
        hyper_encoder = nll_codec.HyperEncoder(192)
        hyper_decoder_sigma = nll_codec.HyperDecoder(192)
        hyper_decoder_mu = nll_codec.HyperDecoder(192)
        entropy_parameters_sigma = nll_codec.EntropyParameters(192)
        entropy_parameters_mu = nll_codec.EntropyParameters(192)
        entropy_bottleneck = tfc.EntropyBottleneck()
        res_compressor = nll_codec.ResidualCompressor(128, 5)
        masked_conv = nll_codec.MaskedConv2d("A", 64, (5, 5), padding="VALID")
        res_compressor_cond = bc.ResidualCompressor_cond(128, 5)

        # build model and encode/decode
        y = encoder(x_norm)
        z = hyper_encoder(y)
        side_string = entropy_bottleneck.compress(
            z)  # encode z (including quantization)
        z_hat_decode = entropy_bottleneck.decompress(
            side_string, tf.shape(z)[1:],
            channels=192)  # decode z (including dequantization)
        psi_sigma = hyper_decoder_sigma(z_hat_decode)
        psi_mu = hyper_decoder_mu(z_hat_decode)
        sigma = entropy_parameters_sigma(psi_sigma)
        mu = entropy_parameters_mu(psi_mu)
        scale_table = np.exp(
            np.linspace(np.log(SCALE_MIN), np.log(SCALE_MAX), SCALES_LEVELS))
        conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                         scale_table,
                                                         mean=mu)
        string = conditional_bottleneck.compress(
            y)  # encode y (including quantization)
        y_hat_decode = conditional_bottleneck.decompress(
            string)  # decode y (including dequantization)

        # Lossy reconstruction, cropped back to the unpadded size.
        x_hat, res_prior = decoder(y_hat_decode)
        x_hat = x_hat[:, :x_shape[1], :x_shape[2], :]
        x_hat = tf.clip_by_value(x_hat, 0, 1)
        x_hat = tf.math.floor(x_hat * 255 + 0.5)
        res_prior = res_prior[:, :x_shape[1], :x_shape[2], :]

        # Quantize the residual to multiples of 2 * tau + 1.
        res = x - x_hat
        res_q = tf.where(res >= 0, (2 * tau + 1) * tf.math.floor(
            (res + tau) / (2 * tau + 1)), (2 * tau + 1) * tf.math.ceil(
                (res - tau) / (2 * tau + 1)))

        tau_list = tf.constant([int(tau - 1)], tf.int32)
        cond = tf.one_hot(tau_list, 5)

        # Residual entropy model, evaluated one pixel at a time.
        res_q_patch = tf.placeholder(dtype=tf.float32, shape=(1, 5, 5, 3))
        res_prior_channel_num = 64
        res_prior_patch = tf.placeholder(dtype=tf.float32,
                                         shape=(1, 1, 1, res_prior_channel_num))
        res_q_vector = tf.placeholder(dtype=tf.float32, shape=(1, 1, 1, 3))

        bin_sz = 2 * tau + 1
        pmf_length = int(510 // bin_sz + 1)
        pmf_end = (255 // bin_sz) * bin_sz

        context = masked_conv(res_q_patch)
        res_prior_context = tf.concat([res_prior_patch, context], 3)

        bias_correction = True
        if bias_correction and int(tau) > 0:
            res_mu, res_log_sigma, res_pi, res_lambda = res_compressor_cond(
                res_prior_context, cond)
        else:
            res_mu, res_log_sigma, res_pi, res_lambda = res_compressor(
                res_prior_context)

        res_mu_tiled = tf.tile(res_mu,
                               tf.constant([pmf_length, 1, 1, 1], tf.int32))
        res_log_sigma_tiled = tf.tile(
            res_log_sigma, tf.constant([pmf_length, 1, 1, 1], tf.int32))
        res_pi_tiled = tf.tile(res_pi,
                               tf.constant([pmf_length, 1, 1, 1], tf.int32))
        res_lambda_tiled = tf.tile(
            res_lambda, tf.constant([pmf_length, 1, 1, 1], tf.int32))

        res_bottleneck = lmm.LogisticMixtureModel(res_mu_tiled,
                                                  res_log_sigma_tiled,
                                                  res_pi_tiled,
                                                  res_lambda_tiled)
        res_pmf = res_bottleneck.pmf_tau(res_q_vector, tau)

        # Lossy-layer quality metrics.
        eval_mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
        eval_psnr = 10 * tf.math.log(255**2 / eval_mse) / tf.math.log(10.0)
        eval_max_abs_diff = tf.reduce_max(tf.abs(tf.subtract(x, x_hat)))

        with tf.Session() as sess:
            latest = tf.train.latest_checkpoint(checkpoint_dir=ckp_dir)
            tf.train.Saver().restore(sess, save_path=latest)

            tensors = [
                string, side_string,
                tf.shape(x)[1:-1],
                tf.shape(y)[1:-1],
                tf.shape(z)[1:-1]
            ]

            # Lossy Image Encoding. A single run fetches the bitstream and
            # everything derived from it, so the encoder/decoder networks are
            # evaluated once instead of twice.
            print("Lossy Image Encoding Start.")
            (arrays, res_prior_out, res_q_out, x_org, x_out, lossy_mse,
             lossy_psnr, lossy_max_abs_diff, x_shape_out) = sess.run([
                 tensors, res_prior, res_q, x, x_hat, eval_mse, eval_psnr,
                 eval_max_abs_diff, x_shape
             ])
            print("Lossy Image Encoding Finish.")

            # write binary file
            packed = tfc.PackedTensors()
            packed.pack(tensors, arrays)
            with open(output_bin_path, "wb") as f:
                f.write(packed.string)

            k_sz = 5
            pad_sz = 2
            _, x_h, x_w, x_c = x_shape_out
            res_q_padded = np.pad(res_q_out, ((0, 0), (pad_sz, pad_sz),
                                              (pad_sz, pad_sz), (0, 0)),
                                  'constant')

            range_encoder = RangeEncoder(output_res_path)
            print('Residual Encoding Start.')
            try:
                for h_idx in range(x_h):
                    for w_idx in range(x_w):
                        res_q_extracted = res_q_padded[:, h_idx:h_idx + k_sz,
                                                       w_idx:w_idx + k_sz, :]
                        res_prior_extracted = res_prior_out[:, h_idx,
                                                            w_idx, :].reshape(
                                                                1, 1, 1,
                                                                res_prior_channel_num)
                        res_q_vector_extracted = res_q_out[:, h_idx,
                                                           w_idx, :].reshape(
                                                               1, 1, 1, 3)
                        res_pmf_out = sess.run(res_pmf,
                                               feed_dict={
                                                   res_q_patch: res_q_extracted,
                                                   res_prior_patch: res_prior_extracted,
                                                   res_q_vector: res_q_vector_extracted
                                               })
                        # Symbol index of each channel's quantized residual
                        # (the decoder inverts this as index * bin_sz - pmf_end).
                        symbols = (res_q_vector_extracted[0, 0, 0, :] +
                                   pmf_end) // bin_sz
                        for c_idx in range(x_c):
                            # Must match the decoder's pmf -> cumulative
                            # frequency conversion exactly.
                            c_pmf = res_pmf_out[:, 0, 0, c_idx]
                            c_pmf = np.clip(c_pmf, 1.0 / 65025, 1.0)
                            c_pmf = c_pmf / np.sum(c_pmf)
                            cum_freq = np.floor(
                                np.append([0.], np.cumsum(c_pmf)) * 65536. +
                                0.5).astype(np.int32).tolist()
                            range_encoder.encode([int(symbols[c_idx])], cum_freq)
                print("Encoding Finish.")
            finally:
                range_encoder.close()

            print("Lossy MSE:{}, Lossy PSNR:{}, Lossy max_abs_diff:{}".format(
                lossy_mse, lossy_psnr, lossy_max_abs_diff))

            img_sz_out = os.path.getsize(output_bin_path)
            res_sz_out = os.path.getsize(output_res_path)
            eval_sz_out = img_sz_out + res_sz_out
            num_subpixels = x_c * x_h * x_w
            img_bpsp = img_sz_out * 8 / num_subpixels
            res_bpsp = res_sz_out * 8 / num_subpixels
            eval_bpsp = img_bpsp + res_bpsp
            print("tau:{}, bpsp:{}, img_bpsp:{}, res_bpsp:{}".format(
                tau, eval_bpsp, img_bpsp, res_bpsp))

            x_rec = np.clip(np.squeeze(x_out + res_q_out), 0, 255)
            max_abs_diff = np.amax(np.abs(x_org - x_rec))
            mse = np.mean((x_org - x_rec)**2)
            # A perfect reconstruction (mse == 0) has infinite PSNR; avoid
            # numpy's divide-by-zero warning while printing the same "inf".
            psnr = np.inf if mse == 0 else 10 * np.log10(255**2 / mse)
            print("Max abs diff:{}, NLL MSE:{}, NLL PSNR:{}".format(
                max_abs_diff, mse, psnr))

            return eval_sz_out, img_sz_out, res_sz_out
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format.

    For every batch:
      1. The latents ``y`` and hyper-latents ``z`` are initialized from the
         amortized encoders.
      2. They are refined with ``rd_opt_its`` Adam steps on the relaxed
         rate-distortion loss ``lmbda * mse + bpp``, where uniform noise
         stands in for rounding.
      3. The refined latents are quantized. ``z`` is rounded
         (median-centered), and ``y`` is rounded around the mean predicted
         from the quantized ``z``, matching what a decoder would reconstruct.
      4. The quantized latents are evaluated.

    Per-image metrics are saved as an ``.npz`` file under
    ``args.results_dir``. If the ``save_opt_record`` flag is set, the last
    batch's optimization trace is saved as well. That flag is not defined in
    this function; presumably it is module-level.

    Args:
      args: parsed command-line namespace; reads ``input_file``,
        ``num_filters``, ``lmbda``, ``runname``, ``checkpoint_dir``,
        ``results_dir`` and ``verbose``. If ``args.lmbda`` is negative, it
        is overwritten with the lambda parsed from ``args.runname``.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of
        # an array of shape [N, H, W, 3].
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # Every sess.run whose fetches depend on the iterator consumes a new batch.
    # Batches are therefore pulled explicitly through x_next and fed via the
    # x_ph placeholder, so many separate sess.run calls can reuse one batch.
    x_next = dataset.make_one_shot_iterator().get_next()
    x_ph = x = tf.placeholder(
        'float32', (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(
        args.num_filters, num_output_filters=2 * args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization.
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    # The latents being optimized are fed in as numpy arrays.
    y = tf.placeholder('float32', y_init.shape)
    y_tilde = y + tf.random.uniform(tf.shape(y), -0.5, 0.5)
    z = tf.placeholder('float32', z_init.shape)

    # Sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute
    # the pdf of z_tilde under the flexible prior p(z_tilde) ("z_likelihoods").
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=True)
    z_hat = entropy_bottleneck._quantize(
        z, 'dequantize')  # rounded (with median centering)

    # Parameters of p(y_tilde|z_tilde), computed from z_tilde.
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    # Handle images with non-standard sizes; mu/sigma must match y's shape.
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # pdf of y_tilde under the conditional prior p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5).
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    # Rounded with mean centering. mu depends on z_tilde, so z_tilde must be
    # fed the quantized z_hat to get the latents a decoder would reconstruct.
    y_hat = conditional_bottleneck._quantize(y, 'dequantize')

    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # Total number of bits divided by number of pixels:
    # - log p(\tilde y | \tilde z) - log p(\tilde z)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels, multiplied by 255^2 to correct for
    # the rescaling to [0, 1].
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats.
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: [] for key in eval_fields
                           }  # append across all batches

        log_itv = 100
        if save_opt_record:
            log_itv = 10
        rd_lr = 0.005
        rd_opt_its = 2000
        from adam import Adam

        def quantize_latents(y_val, z_val):
            """Return (y_hat, z_hat) arrays, with y rounded around the mean predicted from z_hat."""
            z_hat_val = sess.run(z_hat, feed_dict={z: z_val})
            y_hat_val = sess.run(y_hat, feed_dict={y: y_val, z_tilde: z_hat_val})
            return y_hat_val, z_hat_val

        # Defined up front so it exists even if no batch is processed.
        opt_record = {'its': [], 'rd_loss': [], 'rd_loss_after_rounding': []}
        while True:
            try:
                x_val = sess.run(x_next)
            except tf.errors.OutOfRangeError:
                break
            x_feed_dict = {x_ph: x_val}

            # 1. Perform R-D optimization conditioned on ground truth x.
            print('----RD Optimization----')
            y_cur, z_cur = sess.run([y_init, z_init],
                                    feed_dict=x_feed_dict)  # np arrays
            adam_optimizer = Adam(lr=rd_lr)
            opt_record = {
                'its': [],
                'rd_loss': [],
                'rd_loss_after_rounding': []
            }
            for it in range(rd_opt_its):
                grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                    [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                    feed_dict={
                        y: y_cur,
                        z: z_cur,
                        **x_feed_dict
                    })
                y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)

                if it % log_itv == 0 or it + 1 == rd_opt_its:
                    psnr_ = psnr_.mean()
                    if args.verbose:
                        y_hat_, z_hat_ = quantize_latents(y_cur, z_cur)
                        bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                            [train_bpp, psnr, rd_loss],
                            feed_dict={
                                y_tilde: y_hat_,
                                z_tilde: z_hat_,
                                **x_feed_dict
                            })
                        psnr_after_rounding = psnr_after_rounding.mean()
                        print(
                            'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                            % (it, obj, mse_, train_bpp_, psnr_,
                               rd_loss_after_rounding, bpp_after_rounding,
                               psnr_after_rounding))
                        opt_record['rd_loss_after_rounding'].append(
                            rd_loss_after_rounding)
                    else:
                        print(
                            'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                            % (it, obj, mse_, train_bpp_, psnr_))
                    opt_record['its'].append(it)
                    opt_record['rd_loss'].append(obj)

            print()

            # These are the latents we end up transmitting.
            y_hat_, z_hat_ = quantize_latents(y_cur, z_cur)

            # Transform the quantized latents back and measure performance.
            eval_arrs = sess.run(eval_tensors,
                                 feed_dict={
                                     y_tilde: y_hat_,
                                     z_tilde: z_hat_,
                                     **x_feed_dict
                                 })
            for field, arr in zip(eval_fields, eval_arrs):
                all_results_arrs[field] += arr.tolist()

    for field in eval_fields:
        all_results_arrs[field] = np.asarray(all_results_arrs[field])

    input_file = os.path.basename(args.input_file)
    trained_script_name = args.runname.split('-')[0]
    script_name = os.path.splitext(os.path.basename(__file__))[
        0]  # current script name, without extension

    def results_path(prefix):
        """Output .npz path; includes this script's name and lmbda when it differs from the training script."""
        if script_name != trained_script_name:
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        else:
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
        return os.path.join(args.results_dir, save_file)

    # Save RD evaluation results.
    np.savez(results_path('rd'), **all_results_arrs)

    if save_opt_record:
        # Save the optimization record (of the last batch processed).
        np.savez(results_path('opt'), **opt_record)

    for field in eval_fields:
        arr = all_results_arrs[field]
        print('Avg {}: {:0.4f}'.format(field, arr.mean()))
def train(args):
    """Train the multi-stage (residual) hyperprior compression model.

    One stage is unrolled per entry of ``lmbda``. Every stage reuses the same
    ``encoder``/``decoder``/``HyperEncoder``/``HyperDecoder`` objects but owns
    its own ``tfc.EntropyBottleneck`` (``entropy_iter<i>``) and
    ``tfc.GaussianConditional`` (``conditional<i>``). Stage ``i`` codes the
    residual left by stages ``0..i-1``; the reconstruction is a constant 0.5
    image plus the decoded outputs of all stages.

    Args:
      args: parsed command-line namespace. Fields read here: ``verbose``,
        ``train_glob``, ``preprocess_threads``, ``patchsize``, ``batchsize``,
        ``iter`` (number of train ops run per loop; must not exceed
        ``len(lmbda)``), ``last_step`` and ``checkpoint_dir``.

    Raises:
      RuntimeError: if ``args.train_glob`` matches no files.
    """
    if args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # Create input data pipeline.
    with tf.device("/cpu:0"):
        train_files = glob.glob(args.train_glob)
        if not train_files:
            raise RuntimeError(
                "No training images found with glob '{}'.".format(args.train_glob))
        train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
        train_dataset = train_dataset.shuffle(buffer_size=len(train_files)).repeat()
        train_dataset = train_dataset.map(
            read_png, num_parallel_calls=args.preprocess_threads)
        train_dataset = train_dataset.map(
            lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3)))
        train_dataset = train_dataset.batch(args.batchsize)
        train_dataset = train_dataset.prefetch(32)

    num_pixels = args.batchsize * args.patchsize ** 2
    # One rate-distortion trade-off per stage.
    lmbda = [0.01, 0.02, 0.04, 0.08]

    # Get training patch from dataset.
    inputs = train_dataset.make_one_shot_iterator().get_next()
    # Residual still to be coded; the stages start from a constant 0.5 image.
    x = inputs - 0.5

    e = encoder(args.batchsize, is_training=True)
    d = decoder(args.batchsize)
    he = HyperEncoder(args.batchsize, is_training=True)
    hd = HyperDecoder(args.batchsize)

    # Build autoencoder and hyperprior, one stage per lambda.
    entropy_bottlenecks = []
    train_loss = 0
    total_bpp = 0      # rate accumulated over the stages built so far
    stage_bpps = []
    total_mse = 0      # sum of the per-stage (residual) MSEs
    stage_mses = []
    output = tf.zeros_like(x) + 0.5  # running reconstruction, [0, 1] scale
    train_ops = []
    for i, lmb in enumerate(lmbda):
        y = e.encode(x)
        z = he.hyper_encode(abs(y))
        entropy_bottleneck = tfc.EntropyBottleneck(name='entropy_iter' + str(i))
        entropy_bottlenecks.append(entropy_bottleneck)
        z_tilde, z_likelihoods = entropy_bottleneck(z, training=True)
        sigma = hd.hyper_decode(z_tilde)
        scale_table = np.exp(np.linspace(
            np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
        conditional_bottleneck = tfc.GaussianConditional(
            sigma, scale_table, name='conditional' + str(i))
        y_tilde, y_likelihoods = conditional_bottleneck(y, training=True)
        x_tilde = d.decode(y_tilde)

        # Total number of bits divided by number of pixels.
        train_bpp = (tf.reduce_sum(tf.log(y_likelihoods)) +
                     tf.reduce_sum(tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)
        # Mean squared error across pixels.
        train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
        # Multiply by 255^2 to correct for rescaling.
        train_mse *= 255 ** 2

        total_bpp += train_bpp
        total_mse += train_mse
        stage_bpps.append(train_bpp)
        stage_mses.append(train_mse)
        # Rate-distortion cost of stage i: its distortion (which equals the
        # error of the cumulative reconstruction) plus the rate of stages 0..i.
        train_loss += (lmb * train_mse + total_bpp)

        # The next stage codes what this stage failed to reconstruct.
        x = x - x_tilde
        output += x_tilde

    # Minimize loss and auxiliary loss, and execute update op.
    step = tf.train.create_global_step()
    main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    main_step = main_optimizer.minimize(train_loss, global_step=step)

    aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    for i in range(args.iter):
        aux_step = aux_optimizer.minimize(entropy_bottlenecks[i].losses[0])
        # NOTE(review): every train op groups main_step, so one pass of the
        # training loop below applies the main update args.iter times.
        train_ops.append(
            tf.group(main_step, aux_step, entropy_bottlenecks[i].updates[0]))

    tf.summary.scalar("loss", train_loss)
    tf.summary.scalar("bpp", total_bpp)
    tf.summary.scalar("mse", total_mse)
    tf.summary.image("original", quantize_image(inputs))
    tf.summary.image("reconstruction", quantize_image(output))

    class LoggerHook(tf.train.SessionRunHook):
        """Print metrics every 50 steps; append them to a file every 1000 steps."""

        def begin(self):
            self.step = -1
            self.start_time = time.time()

        def before_run(self, run_context):
            self.step += 1
            # Fetch these alongside the training op of every step.
            return tf.train.SessionRunArgs(
                [train_loss, total_bpp, stage_bpps, total_mse, stage_mses,
                 inputs, output])

        def after_run(self, run_context, run_values):
            display_step = 50  # step interval between console reports
            if self.step % display_step != 0:
                return

            current_time = time.time()
            duration = current_time - self.start_time
            self.start_time = current_time

            loss, bpp, bpp1, mse, mse1, original, compressed_img = run_values.results
            print(bpp1, mse1)

            # Both fetched images are on a [0, 1] scale. Rescale out of place so
            # the fetched array is never multiplied by 255 twice.
            original = original * 255.0
            compressed_img = np.array(
                np.clip(compressed_img * 255.0, 0.0, 255.0), dtype=np.uint8)
            ms_ssim = msssim(compressed_img, original)
            # NOTE: mse is the sum of the per-stage MSEs, not the final MSE.
            psnr = 20 * math.log10(255.0 / math.sqrt(mse))
            print([20 * math.log10(255.0 / math.sqrt(mse1[i]))
                   for i in range(args.iter)])

            # Samples per second and time per batch over the last interval.
            examples_per_sec = display_step * args.batchsize / duration
            sec_per_batch = float(duration / display_step)
            format_str = ('%s: step %d, loss = %.2f, Bpp = %.2f, MSE = %.2f, MS-SSIM = %.2f, PSNR = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), self.step, loss, bpp, mse, ms_ssim,
                                psnr, examples_per_sec, sec_per_batch))

            if self.step % (display_step * 20) == 0:
                format_str = ('%s: step %d, loss = %.2f, Bpp = %.2f, MSE = %.2f, MS-SSIM = %.2f, PSNR = %.2f')
                with open("rnn_256-512_0.01-0.08_loss.txt", 'a+') as fin:
                    fin.write(format_str % (datetime.now(), self.step, loss, bpp,
                                            mse, ms_ssim, psnr))
                    fin.write("\n")

    with tf.train.MonitoredTrainingSession(
            hooks=[
                tf.train.StopAtStepHook(last_step=args.last_step),
                tf.train.NanTensorHook(train_loss),
                LoggerHook()],
            checkpoint_dir=args.checkpoint_dir,
            save_checkpoint_secs=300,
            save_summaries_secs=60) as sess:
        while not sess.should_stop():
            for i in range(args.iter):
                sess.run(train_ops[i])
def train(args):
    """Train the single-stage hyperprior model.

    Random ``args.patchsize`` crops of the images matched by
    ``args.train_glob`` are pushed through analysis -> hyper-analysis ->
    factorized entropy bottleneck -> hyper-synthesis (scales) -> Gaussian
    conditional -> synthesis. The loss ``args.lmbda * mse + bpp`` is then
    minimized inside a ``MonitoredTrainingSession`` until ``args.last_step``.

    Raises:
      RuntimeError: if ``args.train_glob`` matches no files.
    """
    if args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # The input pipeline is placed on the CPU.
    with tf.device("/cpu:0"):
        image_paths = glob.glob(args.train_glob)
        if not image_paths:
            raise RuntimeError(
                "No training images found with glob '{}'.".format(args.train_glob))
        crop_shape = (args.patchsize, args.patchsize, 3)
        dataset = (
            tf.data.Dataset.from_tensor_slices(image_paths)
            .shuffle(buffer_size=len(image_paths))
            .repeat()
            .map(read_png, num_parallel_calls=args.preprocess_threads)
            .map(lambda img: tf.random_crop(img, crop_shape))
            .batch(args.batchsize)
            .prefetch(32)
        )

    pixels_per_batch = args.batchsize * args.patchsize ** 2
    x = dataset.make_one_shot_iterator().get_next()

    # Model components, constructed in the original order.
    filters = args.num_filters
    g_a = AnalysisTransform(filters)
    g_s = SynthesisTransform(filters)
    h_a = HyperAnalysisTransform(filters)
    h_s = HyperSynthesisTransform(filters)
    factorized_prior = tfc.EntropyBottleneck()

    # Autoencoder with hyperprior (noisy quantization, training mode).
    y = g_a(x)
    z = h_a(abs(y))
    z_tilde, z_likelihoods = factorized_prior(z, training=True)
    sigma = h_s(z_tilde)
    scale_table = np.exp(np.linspace(
        np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    gaussian = tfc.GaussianConditional(sigma, scale_table)
    y_tilde, y_likelihoods = gaussian(y, training=True)
    x_tilde = g_s(y_tilde)

    # Rate in bits per pixel; distortion as MSE on the 0..255 scale.
    log_likelihood = (tf.reduce_sum(tf.log(y_likelihoods)) +
                      tf.reduce_sum(tf.log(z_likelihoods)))
    train_bpp = log_likelihood / (-np.log(2) * pixels_per_batch)
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde)) * 255 ** 2
    train_loss = args.lmbda * train_mse + train_bpp

    # Main and auxiliary (entropy bottleneck) optimizers, plus CDF updates.
    global_step = tf.train.create_global_step()
    main_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(
        train_loss, global_step=global_step)
    aux_step = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(
        factorized_prior.losses[0])
    train_op = tf.group(main_step, aux_step, factorized_prior.updates[0])

    tf.summary.scalar("loss", train_loss)
    tf.summary.scalar("bpp", train_bpp)
    tf.summary.scalar("mse", train_mse)
    tf.summary.image("original", quantize_image(x))
    tf.summary.image("reconstruction", quantize_image(x_tilde))

    session_hooks = [
        tf.train.StopAtStepHook(last_step=args.last_step),
        tf.train.NanTensorHook(train_loss),
    ]
    with tf.train.MonitoredTrainingSession(
            hooks=session_hooks, checkpoint_dir=args.checkpoint_dir,
            save_checkpoint_secs=300, save_summaries_secs=60) as sess:
        while not sess.should_stop():
            sess.run(train_op)
def compress(args):
    """Compress every image matched by ``args.compress_glob`` to a .tfci file.

    For each image the default graph is reset and rebuilt. A masked
    (context) convolution over the rounded latents is concatenated with the
    hyper-synthesis output, and ``EntropyParameters`` turns the result into
    the mean and scale of the Gaussian conditional.

    After restoring the latest checkpoint from ``args.checkpoint_dir``, the
    latent string, the side string and the x/y/z spatial shapes are packed
    into ``./results/lmbda_0.1_1x1/test/<name>.tfci``. With ``args.verbose``,
    metrics are appended to the module-level ``a_*`` lists and written to
    ``log.logger``.

    Raises:
      RuntimeError: if ``args.compress_glob`` matches no files.
    """
    with tf.device("/cpu:0"):
        test_files = glob.glob(args.compress_glob)
        if not test_files:
            raise RuntimeError(
                "No test images found with glob '{}'.".format(args.compress_glob))

    for input_file in test_files:
        # Name component just before the last "." of the base name.
        stem = input_file.split("/")[-1].split(".")[-2]
        output_file = "./results/lmbda_0.1_1x1/test/" + stem + ".tfci"
        tf.reset_default_graph()

        # Load input image and add batch dimension.
        x = tf.expand_dims(read_png(input_file), 0)
        x.set_shape([1, None, None, 3])
        x_shape = tf.shape(x)

        # Model components, constructed in the original order.
        analysis_transform = AnalysisTransform(args.num_filters)
        synthesis_transform = SynthesisTransform(args.num_filters)
        hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
        hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters)
        entropy_bottleneck = tfc.EntropyBottleneck()
        context_prediction = MaskedConvolution2D()
        entropy_prediction = EntropyParameters()

        # Analysis and hyperprior.
        y = analysis_transform(x)
        y_shape = tf.shape(y)
        z = hyper_analysis_transform(abs(y))
        z_hat, z_likelihoods = entropy_bottleneck(z, training=False)
        hyper_params = hyper_synthesis_transform(z_hat)

        # Context model: masked conv over the rounded latents, combined with
        # the hyper-synthesis output, predicts the Gaussian mean and scale.
        y_bar = quantize(y)
        joint = tf.concat([context_prediction(y_bar), hyper_params], axis=3)
        mean, sigma = entropy_prediction(joint)
        mean = mean[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]

        scale_table = np.exp(np.linspace(
            np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
        conditional_bottleneck = tfc.GaussianConditional(
            sigma, scale_table, dtype=tf.float32, mean=mean)
        side_string = entropy_bottleneck.compress(z)

        # Likelihoods and bitstream for the latents; reconstruct from y_bar.
        y_hat, y_likelihoods = conditional_bottleneck(y, training=False)
        string = conditional_bottleneck.compress(y)
        x_hat = synthesis_transform(y_bar)[:, :x_shape[1], :x_shape[2], :]

        num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), dtype=tf.float32)
        log_likelihood = (tf.reduce_sum(tf.log(y_likelihoods)) +
                          tf.reduce_sum(tf.log(z_likelihoods)))
        eval_bpp = log_likelihood / (-np.log(2) * num_pixels)

        # Compare both images on the 0..255 integer grid.
        x *= 255
        x_hat = tf.round(tf.clip_by_value(x_hat, 0, 1) * 255)
        mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
        psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
        ms_ssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))

        with tf.Session() as sess:
            checkpoint = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
            tf.train.Saver().restore(sess, save_path=checkpoint)

            tensors = [string, side_string, tf.shape(x)[1:-1],
                       tf.shape(y)[1:-1], tf.shape(z)[1:-1]]
            packed = tfc.PackedTensors()
            packed.pack(tensors, sess.run(tensors))
            with open(output_file, "wb") as f:
                f.write(packed.string)

            if args.verbose:
                eval_bpp, mse, psnr, ms_ssim, num_pixels = sess.run(
                    [eval_bpp, mse, psnr, ms_ssim, num_pixels])
                # Actual bits per pixel, including container overhead.
                bpp = len(packed.string) * 8 / num_pixels
                ms_ssim_db = -10 * np.log10(1 - ms_ssim)
                a_mse.append(mse)
                a_msssim.append(ms_ssim)
                a_psnr.append(psnr)
                a_msssim_dB.append(ms_ssim_db)
                a_eval_bpp.append(eval_bpp)
                a_bpp.append(bpp)
                log.logger.info(
                    "Image {}: mse {:0.8f} psnr {:0.8f} msssim {:0.8f} dB {:0.8f} eval_bpp {:0.8f} bpp {:0.8f}"
                    .format(input_file, mse, psnr, ms_ssim, ms_ssim_db, eval_bpp, bpp))
def decompress(args):
    """Decompress every file matched by ``args.decompress_glob`` to a PNG.

    Each file is expected to hold the five arrays written by the context-model
    ``compress``: latent string, side string, and the x/y/z spatial shapes.

    Latents are decoded autoregressively. For every spatial position the whole
    graph is re-run with the latents decoded so far, fed through the ``y_bar``
    placeholder. The image is then synthesized and written to
    ``./results/lmbda_0.01_1x1/test/<name>.png``.

    Raises:
      RuntimeError: if ``args.decompress_glob`` matches no files.
    """
    with tf.device("/cpu:0"):
        binary_files = glob.glob(args.decompress_glob)
        if not binary_files:
            raise RuntimeError(
                "No test images found with glob '{}'.".format(args.decompress_glob))

    for input_file in binary_files:
        # Name component just before the last "." of the base name.
        stem = input_file.split("/")[-1].split(".")[-2]
        output_file = "./results/lmbda_0.01_1x1/test/" + stem + ".png"
        tf.reset_default_graph()

        # Read the shape information and compressed string from the binary file.
        string = tf.placeholder(tf.string, [1])
        side_string = tf.placeholder(tf.string, [1])
        x_shape = tf.placeholder(tf.int32, [2])
        y_shape = tf.placeholder(tf.int32, [2])
        z_shape = tf.placeholder(tf.int32, [2])
        with open(input_file, "rb") as f:
            packed = tfc.PackedTensors(f.read())
        tensors = [string, side_string, x_shape, y_shape, z_shape]
        arrays = packed.unpack(tensors)
        # The file holds exactly these five arrays; unpack them once. A second
        # two-name unpack of the same list would raise ValueError.
        string, side_string, x_shape, y_shape, z_shape = arrays

        # Instantiate model.
        synthesis_transform = SynthesisTransform(args.num_filters)
        hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters)
        entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32)
        entropy_prediction = EntropyParameters()
        context_prediction = MaskedConvolution2D()

        # Decode the hyper-latents and run the hyper-synthesis transform.
        z_shape = tf.concat([z_shape, [args.num_filters]], axis=0)
        z_hat = entropy_bottleneck.decompress(
            side_string, z_shape, channels=args.num_filters)
        sigma_ = hyper_synthesis_transform(z_hat)

        # Latents decoded so far are fed back in through this placeholder so
        # the masked convolution only sees already-decoded positions.
        y_bar = tf.placeholder(
            tf.float32, shape=(1, y_shape[0], y_shape[1], args.num_filters))
        context = context_prediction(y_bar)
        prediction = tf.concat([context, sigma_], axis=3)
        mean, sigma = entropy_prediction(prediction)
        mean = mean[:, :y_shape[0], :y_shape[1], :]
        sigma = sigma[:, :y_shape[0], :y_shape[1], :]
        scale_table = np.exp(np.linspace(
            np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
        conditional_bottleneck = tfc.GaussianConditional(
            sigma, scale_table, dtype=tf.float32, mean=mean)
        y_hat = conditional_bottleneck.decompress(string)
        x_hat = synthesis_transform(tf.round(y_bar))
        # Remove batch dimension, and crop away any extraneous padding on the
        # bottom or right boundaries.
        x_hat = x_hat[0, :x_shape[0], :x_shape[1], :]
        # Write reconstructed image out as a PNG file.
        op = write_png(output_file, x_hat)

        # Load the latest model checkpoint, and perform the above actions.
        with tf.Session() as sess:
            latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
            tf.train.Saver().restore(sess, save_path=latest)
            y_bar_ = np.zeros(
                (1, y_shape[0], y_shape[1], args.num_filters), dtype=np.float32)
            # NOTE(review): the context fed here is y_hat - mean (the integer
            # symbols), while compress() feeds quantize(y) to the context
            # model. Confirm both sides see the same values.
            for h_idx in range(y_shape[0]):
                for w_idx in range(y_shape[1]):
                    y_hat_, mean_ = sess.run(
                        [y_hat, mean], feed_dict={y_bar: np.around(y_bar_)})
                    y_bar_[:, h_idx, w_idx, :] = (
                        y_hat_[:, h_idx, w_idx, :] - mean_[:, h_idx, w_idx, :])
            sess.run(op, feed_dict={y_bar: y_bar_})
def compress(args):
    """Compress ``args.input_file`` with the multi-stage residual model.

    Stage ``i`` encodes the residual left by stages ``0..i-1`` with its own
    ``entropy_iter<i>`` / ``conditional<i>`` entropy models. The
    reconstruction after stage ``i`` is 0.5 plus the decoded outputs of
    stages ``0..i``.

    For every stage this:
      * prints MSE/PSNR/MS-SSIM of the cumulative reconstruction, plus the
        cumulative information content and the actual bits per pixel;
      * appends one line to ``rnn_256-512_0.01-0.08_results.txt``;
      * saves the reconstruction to ``compressed/recon_<i>.png``.
    """
    # Load input image and add batch dimension. The imread array is only used
    # for its height/width, which the transforms need at construction time.
    image = imread(args.input_file).astype(np.float32)
    img = read_png(args.input_file)
    img = tf.expand_dims(img, 0)
    img.set_shape([1, img.shape[1], img.shape[2], 3])
    x_shape = tf.shape(img)
    x = img - 0.5  # residual still to be coded; stages start from a 0.5 image

    height, width = image.shape[0], image.shape[1]
    e = encoder(args.batchsize, height=height, width=width)
    d = decoder(args.batchsize, height=height, width=width)
    he = HyperEncoder(args.batchsize, height=height // 16, width=width // 16)
    hd = HyperDecoder(args.batchsize, height=height // 16, width=width // 16)

    # Per-stage graph outputs, indexed by stage.
    encodes = []
    hyper_encodes = []
    strings = []
    side_strings = []
    stage_mse = []
    stage_psnr = []
    stage_msssim = []
    stage_eval_bpp = []  # information content accumulated up to each stage
    comps = []           # reconstruction after each stage, 0..255 grid

    eval_bpp = 0
    x_hats = tf.zeros_like(x) + 0.5
    num_pixels = tf.cast(tf.reduce_prod(tf.shape(img)[:-1]), dtype=tf.float32)
    original = img * 255
    scale_table = np.exp(np.linspace(
        np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))

    for i in range(args.iter):
        # Transform and compress the current residual.
        y = e.encode(x)
        encodes.append(y)
        y_shape = tf.shape(y)
        z = he.hyper_encode(abs(y))
        hyper_encodes.append(z)
        entropy_bottleneck = tfc.EntropyBottleneck(name='entropy_iter' + str(i))
        z_hat, z_likelihoods = entropy_bottleneck(z, training=False)
        sigma = hd.hyper_decode(z_hat)
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
        conditional_bottleneck = tfc.GaussianConditional(
            sigma, scale_table, name='conditional' + str(i))
        side_strings.append(entropy_bottleneck.compress(z))
        strings.append(conditional_bottleneck.compress(y))

        # Transform the quantized residual back.
        y_hat, y_likelihoods = conditional_bottleneck(y, training=False)
        x_hat = d.decode(y_hat)
        x_hat = x_hat[:, :x_shape[1], :x_shape[2], :]

        # Total number of bits divided by number of pixels, summed over
        # stages 0..i; record it per stage so reports match the cumulative bpp.
        eval_bpp += ((tf.reduce_sum(tf.log(y_likelihoods)) +
                      tf.reduce_sum(tf.log(z_likelihoods))) / (-np.log(2) * num_pixels))
        stage_eval_bpp.append(eval_bpp)

        x = x - x_hat
        x_hats += x_hat

        # Bring the cumulative reconstruction back to the 0..255 range.
        compressed = tf.round(tf.clip_by_value(x_hats, 0, 1) * 255)
        comps.append(compressed)
        stage_mse.append(tf.reduce_mean(tf.squared_difference(original, compressed)))
        stage_psnr.append(tf.squeeze(tf.image.psnr(compressed, original, 255)))
        stage_msssim.append(
            tf.squeeze(tf.image.ssim_multiscale(compressed, original, 255)))

    with tf.Session() as sess:
        # Load the latest model checkpoint, get the compressed strings and the
        # tensor shapes.
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        bpp = 0  # actual bits per pixel over the stages written so far
        for i in range(args.iter):
            tensors = [strings[i], side_strings[i], tf.shape(img)[1:-1],
                       tf.shape(encodes[i])[1:-1], tf.shape(hyper_encodes[i])[1:-1]]
            arrays = sess.run(tensors)
            # Write a binary file with the shape information and the string.
            packed = tfc.PackedTensors()
            packed.pack(tensors, arrays)
            # NOTE(review): "wb" on every stage means only the last stage's
            # bitstream survives in args.output_file, although bpp counts all.
            with open(args.output_file, "wb") as f:
                f.write(packed.string)

            eval_bpps, mses, psnrs, msssims, num_pixelses, comp = sess.run(
                [stage_eval_bpp[i], stage_mse[i], stage_psnr[i], stage_msssim[i],
                 num_pixels, comps[i]])
            # The actual bits per pixel including overhead.
            bpp += (len(packed.string) * 8 / num_pixelses)
            print("Mean squared error: {:0.4f}".format(mses))
            print("PSNR (dB): {:0.2f}".format(psnrs))
            print("Multiscale SSIM: {:0.4f}".format(msssims))
            print("Multiscale SSIM (dB): {:0.2f}".format(-10 * np.log10(1 - msssims)))
            print("Information content in bpp: {:0.4f}".format(eval_bpps))
            print("Actual bits per pixel: {:0.4f}".format(bpp))
            with open("rnn_256-512_0.01-0.08_results.txt", 'a+') as fin:
                fin.write("Iter %d, %.8f, %.8f, %.8f, %.8f" % (i, mses, psnrs, msssims, bpp))
                fin.write("\n")
            imsave('compressed/recon_' + str(i) + '.png', np.squeeze(comp))
def build_graph(args, x, training=True): """ Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3]. Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest. During training we sample from box-shaped posteriors; during compression this is approximated by rounding. """ # Instantiate model. analysis_transform = AnalysisTransform(args.num_filters) synthesis_transform = SynthesisTransform(args.num_filters) hyper_analysis_transform = HyperAnalysisTransform(args.num_filters, num_output_filters=2 * args.num_filters) hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, num_output_filters=2 * args.num_filters) # entropy_bottleneck = tfc.EntropyBottleneck() # Build autoencoder and hyperprior. y = analysis_transform(x) # y_tilde ~ q(y_tilde | y = g_a(x)) half = tf.constant(.5, dtype=y.dtype) if training: noise = tf.random.uniform(tf.shape(y), -half, half) y_tilde = y + noise else: # Approximately sample from q(y_tilde|x) by rounding. We can't be smart and do y_hat=floor(y + 0.5 - prior_mean) as # in Balle's model (ultimately implemented by conditional_bottleneck._quantize), because we don't have the prior # p(y_tilde | z_tilde) yet; in bb we have to sample z_tilde given y_tilde, whereas in BMSHJ2018, z_tilde is obtained # conditioned on x. 
y_tilde = tf.round(y) # z_tilde ~ q(z_tilde | h_a(\tilde y)) z_mean, z_logvar = tf.split(hyper_analysis_transform(y_tilde), num_or_size_splits=2, axis=-1) eps = tf.random.normal(shape=tf.shape(z_mean)) z_tilde = eps * tf.exp(z_logvar * .5) + z_mean from utils import log_normal_pdf log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar) # bits back # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods") from learned_prior import BMSHJ2018Prior hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3)) z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False) z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound) # compute parameters of p(y_tilde|z_tilde) mu, sigma = tf.split(hyper_synthesis_transform(z_tilde), num_or_size_splits=2, axis=-1) sigma = tf.exp(sigma) # make positive if training: sigma = math_ops.upper_bound(sigma, variance_upperbound**0.5) if not training: # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y y_shape = tf.shape(y) mu = mu[:, :y_shape[1], :y_shape[2], :] sigma = sigma[:, :y_shape[1], :y_shape[2], :] scale_table = np.exp( np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS)) conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mu) # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde) # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5) y_likelihoods = conditional_bottleneck._likelihood( y_tilde) # p(\tilde y | \tilde z) if conditional_bottleneck.likelihood_bound > 0: likelihood_bound = conditional_bottleneck.likelihood_bound y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound) x_tilde = synthesis_transform(y_tilde) if not training: x_shape = tf.shape(x) x_tilde = x_tilde[:, :x_shape[1], :x_shape[ 2], :] # crop reconstruction to have the same shape as input return locals()
def test_compress(args):
    """Sweep lambda levels 0..63 over ``kodak/*.png`` and write rate/MSE to ``f6.csv``.

    The model is conditioned on a one-hot lambda level (depth 64). For every
    level and every Kodak image, the information content in bpp, the MSE on
    the 0..255 scale and the pixel count are written as one CSV row.
    """
    # Load input image and add batch dimension; the filename is fed per run.
    fn = tf.placeholder(tf.string, [])
    x = read_png(fn)
    x = tf.expand_dims(x, 0)
    x.set_shape([1, None, None, 3])
    x_shape = tf.shape(x)

    # Lambda-level conditioning. A random level is the default value, but the
    # sweep below feeds an explicit level into this tensor on every run.
    lmbda_level = tf.random_uniform([], minval=0, maxval=64, dtype=tf.int32)
    lmbda_onehot = tf.one_hot(tf.reshape(lmbda_level, [1]), depth=64)

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters, lmbda_onehot)
    synthesis_transform = SynthesisTransform(args.num_filters, lmbda_onehot)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters, lmbda_onehot)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, lmbda_onehot)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Transform and compress the image.
    y = analysis_transform(x)
    y_shape = tf.shape(y)
    z = hyper_analysis_transform(abs(y))
    z_hat, z_likelihoods = entropy_bottleneck(z, training=False)
    sigma = hyper_synthesis_transform(z_hat)
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table)
    side_string = entropy_bottleneck.compress(z)
    string = conditional_bottleneck.compress(y)

    # Transform the quantized image back.
    y_hat, y_likelihoods = conditional_bottleneck(y, training=False)
    x_hat = synthesis_transform(y_hat)
    x_hat = x_hat[:, :x_shape[1], :x_shape[2], :]

    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), dtype=tf.float32)

    # Total number of bits divided by number of pixels.
    eval_bpp = (tf.reduce_sum(tf.log(y_likelihoods)) + tf.reduce_sum(
        tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)

    # Bring both images back to 0..255 range.
    x *= 255
    x_hat = tf.clip_by_value(x_hat, 0, 1)
    x_hat = tf.round(x_hat * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
    psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
    msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))

    with tf.Session() as sess:
        # Load the latest model checkpoint.
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # The file list does not change during the sweep; list it once.
        filenames = glob.glob("kodak/*.png")
        # "with" closes the CSV even if a sess.run call raises.
        with open("f6.csv", "w") as f:
            print("level, fn, bpp, mse, np", file=f)
            for level in range(64):
                for filename in filenames:
                    v_lmbda_level, v_eval_bpp, v_mse, v_num_pixels = sess.run(
                        [lmbda_level, eval_bpp, mse, num_pixels],
                        feed_dict={fn: filename, lmbda_level: level})
                    print(
                        "%.2f, %s, %.4f, %.4f, %d" %
                        (v_lmbda_level, filename, v_eval_bpp, v_mse, v_num_pixels),
                        file=f)
def compress(args):
    """Compress ``args.input_file`` into ``args.output_file``.

    The packed output holds the latent string, the hyper-latent (side) string
    and the spatial shapes of x, y and z. When ``args.verbose`` is set, the
    distortion and rate of the round-tripped image are also printed.
    """
    # Read the image and give it a batch dimension.
    image = tf.expand_dims(read_png(args.input_file), 0)
    image.set_shape([1, None, None, 3])
    image_shape = tf.shape(image)

    # Model components, constructed in the same order as during training.
    filters = args.num_filters
    g_a = AnalysisTransform(filters)
    g_s = SynthesisTransform(filters)
    h_a = HyperAnalysisTransform(filters)
    h_s = HyperSynthesisTransform(filters)
    factorized = tfc.EntropyBottleneck()

    # Forward pass through the hyperprior.
    latent = g_a(image)
    latent_shape = tf.shape(latent)
    hyper_latent = h_a(abs(latent))
    hyper_hat, hyper_likelihoods = factorized(hyper_latent, training=False)
    scale = h_s(hyper_hat)[:, :latent_shape[1], :latent_shape[2], :]
    table = np.exp(np.linspace(
        np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    gaussian = tfc.GaussianConditional(scale, table)

    side_string = factorized.compress(hyper_latent)
    string = gaussian.compress(latent)

    # Reconstruction, cropped back to the input size.
    latent_hat, latent_likelihoods = gaussian(latent, training=False)
    recon = g_s(latent_hat)[:, :image_shape[1], :image_shape[2], :]

    num_pixels = tf.cast(tf.reduce_prod(tf.shape(image)[:-1]), dtype=tf.float32)
    log_likelihood = (tf.reduce_sum(tf.log(latent_likelihoods)) +
                      tf.reduce_sum(tf.log(hyper_likelihoods)))
    eval_bpp = log_likelihood / (-np.log(2) * num_pixels)

    # Compare both images on the 0..255 integer grid.
    image_255 = image * 255
    recon_255 = tf.round(tf.clip_by_value(recon, 0, 1) * 255)
    mse = tf.reduce_mean(tf.squared_difference(image_255, recon_255))
    psnr = tf.squeeze(tf.image.psnr(recon_255, image_255, 255))
    msssim = tf.squeeze(tf.image.ssim_multiscale(recon_255, image_255, 255))

    with tf.Session() as sess:
        checkpoint = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=checkpoint)

        tensors = [string, side_string, tf.shape(image)[1:-1],
                   tf.shape(latent)[1:-1], tf.shape(hyper_latent)[1:-1]]
        packed = tfc.PackedTensors()
        packed.pack(tensors, sess.run(tensors))
        with open(args.output_file, "wb") as f:
            f.write(packed.string)

        if not args.verbose:
            return

        eval_bpp, mse, psnr, msssim, num_pixels = sess.run(
            [eval_bpp, mse, psnr, msssim, num_pixels])
        # Real file size in bits per pixel, container overhead included.
        bpp = len(packed.string) * 8 / num_pixels
        report = (
            ("Mean squared error: {:0.4f}", mse),
            ("PSNR (dB): {:0.2f}", psnr),
            ("Multiscale SSIM: {:0.4f}", msssim),
            ("Multiscale SSIM (dB): {:0.2f}", -10 * np.log10(1 - msssim)),
            ("Information content in bpp: {:0.4f}", eval_bpp),
            ("Actual bits per pixel: {:0.4f}", bpp),
        )
        for template, value in report:
            print(template.format(value))
def test_compress_mask(args):
    """Dump the ``fc_u`` Softplus activations of every transform per lambda level.

    The model is built with one-hot lambda conditioning (depth 8). For each
    level 0..7, every tensor ``<trans>_transform/layer_<n>/fc_u/Softplus:0`` is
    evaluated and reshaped to 256 values. One row per (level, tensor, width
    index) is written to ``dynamic_mask.csv``.
    """
    # Load input image and add batch dimension.
    fn = tf.placeholder(tf.string, [])
    x = read_png(fn)
    x = tf.expand_dims(x, 0)
    x.set_shape([1, None, None, 3])
    x_shape = tf.shape(x)

    lmbda_level = tf.placeholder(tf.int32, [])
    lmbda_onehot = tf.one_hot(tf.reshape(lmbda_level, [1]), depth=8)

    actives = [tf.constant(256), tf.constant(256), tf.constant(256), tf.constant(3)]

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters, lmbda_onehot)
    synthesis_transform = SynthesisTransform(args.num_filters, lmbda_onehot, actives)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters, lmbda_onehot)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, lmbda_onehot)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Transform and compress the image.
    y = analysis_transform(x)
    y_shape = tf.shape(y)
    z = hyper_analysis_transform(abs(y))
    z_hat, z_likelihoods = entropy_bottleneck(z, training=False)
    sigma = hyper_synthesis_transform(z_hat)
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(np.linspace(
        np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table)
    side_string = entropy_bottleneck.compress(z)
    string = conditional_bottleneck.compress(y)

    # Transform the quantized image back.
    y_hat, y_likelihoods = conditional_bottleneck(y, training=False)
    x_hat = synthesis_transform(y_hat)
    x_hat = x_hat[:, :x_shape[1], :x_shape[2], :]

    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), dtype=tf.float32)

    # Total number of bits divided by number of pixels.
    eval_bpp = (tf.reduce_sum(tf.log(y_likelihoods)) +
                tf.reduce_sum(tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)

    # Bring both images back to 0..255 range.
    x_im = x * 255
    x_hat = tf.clip_by_value(x_hat, 0, 1)
    x_hat_im = tf.round(x_hat * 255)

    mse = tf.reduce_mean(tf.squared_difference(x_im, x_hat_im))
    psnr = tf.squeeze(tf.image.psnr(x_hat_im, x_im, 255))
    msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat_im, x_im, 255))

    with tf.Session() as sess:
        # Load the latest model checkpoint.
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)

        tn_iter = {
            "trans": ["analysis", "hyper_analysis", "hyper_synthesis", "synthesis"],
            "layers": [4, 3, 3, 3],
            "ops": ["fc_u"]
        }
        tn_names = ["{}_transform/layer_{}/{}/Softplus:0".format(tran, ln, op)
                    for tran, layer in zip(tn_iter["trans"], tn_iter["layers"])
                    for ln in range(layer) for op in tn_iter["ops"]]
        print(tn_names)

        import pandas as pd

        # Build the reshape ops once. Creating them inside the loop would add
        # a new op to the graph on every iteration.
        graph = tf.get_default_graph()
        flat_tensors = [tf.reshape(graph.get_tensor_by_name(name), [256])
                        for name in tn_names]

        frames = []
        for i in np.arange(0, 8):
            # One run per level fetches all activations at once.
            values = sess.run(flat_tensors, feed_dict={lmbda_level: i})
            for name, tnv in zip(tn_names, values):
                frames.append(pd.DataFrame(
                    {"level": i, "name": name, "width": range(256), "value": tnv}))
        # DataFrame.append was removed in pandas 2.0. concat keeps the same
        # row order and the same per-frame index labels.
        df = pd.concat(frames)
        df.to_csv("dynamic_mask.csv")