def call(self, inputs, training):
    """Run a tensor through the bottleneck and score the result.

    Args:
      inputs: The tensor to be passed through the bottleneck. It is converted
        to a tensor of `self.dtype` before anything else happens.
      training: Boolean. When truthy, the result is a differentiable
        approximation of the inputs together with their likelihoods under the
        modeled probability densities. When falsy, the result is the quantized
        inputs and their likelihoods under the corresponding probability mass
        function; those are not differentiable and so can't be used for
        training, but they represent actual compression more closely.

    Returns:
      values: `Tensor` with the same shape as `inputs` containing the
        perturbed or quantized input values.
      likelihood: `Tensor` with the same shape as `inputs` containing the
        likelihood of `values` under the modeled probability distributions.

    Raises:
      ValueError: if `inputs` has an integral or inconsistent `DType`, or
        inconsistent number of channels.
    """
    inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
    if inputs.dtype.is_integer:
        message = "{} can't take integer inputs.".format(type(self).__name__)
        raise ValueError(message)

    # `_quantize` is driven by a mode string selected from the training flag.
    mode = "noise" if training else "dequantize"
    outputs = self._quantize(inputs, mode)
    assert outputs.dtype == self.dtype

    likelihood = self._likelihood(outputs)
    if self.likelihood_bound > 0:
        floor = tf.constant(self.likelihood_bound, dtype=self.dtype)
        likelihood = math_ops.lower_bound(likelihood, floor)

    if tf.executing_eagerly():
        return outputs, likelihood

    # Outside eager execution, attach the static shapes reported by
    # `compute_output_shape` to both results.
    static_shapes = self.compute_output_shape(inputs.shape)
    outputs_shape, likelihood_shape = static_shapes
    outputs.set_shape(outputs_shape)
    likelihood.set_shape(likelihood_shape)
    return outputs, likelihood
def _test_lower_bound(self, gradient):
    """Check `math_ops.lower_bound` outputs and gradients for one mode.

    The op is applied with a bound of 0 to the values [-1, 1]. Its gradient is
    queried twice, once with an upstream gradient of +1 and once with -1, and
    both results are compared with the values expected for `gradient`.

    Args:
      gradient: Value forwarded as the `gradient` argument of
        `math_ops.lower_bound`. `"disconnected"` and `"identity"` have their
        own expectations; any other value uses the remaining expectation set.
    """
    x = tf.placeholder(dtype=tf.float32)
    bounded = math_ops.lower_bound(x, 0, gradient=gradient)
    grad_pos, = tf.gradients([bounded], [x], [tf.ones_like(x)])
    grad_neg, = tf.gradients([bounded], [x], [-tf.ones_like(x)])

    feed_values = [-1, 1]
    want_outputs = [0, 1]
    if gradient == "identity":
        want_pos, want_neg = [1, 1], [-1, -1]
    elif gradient == "disconnected":
        want_pos, want_neg = [0, 1], [0, -1]
    else:
        want_pos, want_neg = [0, 1], [-1, -1]

    with self.test_session() as sess:
        got_outputs, got_pos, got_neg = sess.run(
            [bounded, grad_pos, grad_neg], {x: feed_values})
        self.assertAllEqual(got_outputs, want_outputs)
        self.assertAllEqual(got_pos, want_pos)
        self.assertAllEqual(got_neg, want_neg)
def build(self, input_shape):
    """Create the quantized CDF tables and the per-element table indexes.

    The quantized CDF table is derived from the scale table and stored in
    non-trainable weights. Unless `_indexes` has already been set on the
    instance, the graph that maps every entry of the scale tensor to a row of
    that table is created as well. Finally the per-row offsets are stored.

    Args:
      input_shape: Shape of the input tensor.

    Raises:
      ValueError: If `input_shape` doesn't specify number of input dimensions.
    """
    input_shape = tf.TensorShape(input_shape)
    input_shape.assert_is_compatible_with(self.input_spec.shape)

    scale_table = tf.constant(self.scale_table, dtype=self.dtype)

    # The scales are bounded here rather than in __init__, because the dtype
    # may not yet be known there.
    if self.scale_bound is None:
        self._scale = math_ops.lower_bound(self._scale, scale_table[0])
    elif self.scale_bound > 0:
        self._scale = math_ops.lower_bound(self._scale, self.scale_bound)

    quantile_factor = -self._standardized_quantile(self.tail_mass / 2)
    centers = np.ceil(
        np.array(self.scale_table) * quantile_factor).astype(int)
    lengths = 2 * centers + 1
    widest = np.max(lengths)
    table_shape = (len(lengths), widest + 2)

    # Relies on the standardized cumulative satisfying 1 - c(x) = c(-x), so
    # all differences can be taken in the left tail. Subtracting two numbers
    # close to 0 is far more precise than subtracting two numbers close to 1.
    distances = np.abs(np.arange(widest, dtype=int) - centers[:, None])
    distances = tf.constant(distances, dtype=self.dtype)
    column_scales = tf.expand_dims(scale_table, 1)
    cdf_upper = self._standardized_cumulative((.5 - distances) / column_scales)
    cdf_lower = self._standardized_cumulative((-.5 - distances) / column_scales)
    pmf = cdf_upper - cdf_lower

    # Out-of-range (tail) masses, taken from the first column.
    tails = 2 * cdf_lower[:, :1]

    def make_cdf_table(shape, dtype=None, partition_info=None):
        del partition_info  # unused
        assert tuple(shape) == table_shape
        assert dtype == tf.int32
        return self._pmf_to_cdf(
            pmf, tails, tf.constant(lengths, dtype=tf.int32), widest)

    quantized_cdf = self.add_weight(
        "quantized_cdf",
        shape=table_shape,
        initializer=make_cdf_table,
        dtype=tf.int32,
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    cdf_length = self.add_weight(
        "cdf_length",
        shape=(len(lengths),),
        initializer=tf.initializers.constant(lengths + 2),
        dtype=tf.int32,
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    # Works around a weird TF issue with reading variables inside a loop.
    self._quantized_cdf = tf.identity(quantized_cdf)
    self._cdf_length = tf.identity(cdf_length)

    # Only compute the table indexes if they haven't been overridden.
    if not hasattr(self, "_indexes"):
        # Prevent tensors from bouncing back and forth between host and GPU.
        with tf.device("/cpu:0"):
            last_row = tf.constant(len(self.scale_table) - 1, dtype=tf.int32)
            start = tf.fill(tf.shape(self.scale), last_row)

            # Starting from the last row, step one row down for every table
            # entry (except the last) that is >= the element's scale.
            self._indexes = tf.foldr(
                lambda acc, level: acc - tf.cast(self.scale <= level, tf.int32),
                scale_table[:-1],
                initializer=start,
                back_prop=False,
                name="compute_indexes")

    self._offset = tf.constant(-centers, dtype=tf.int32)

    super(SymmetricConditional, self).build(input_shape)
def __call__(self) -> tf.Tensor:
    """Return the reparameterized value as a `tf.Tensor`.

    The stored variable is lower-bounded by `self._bound`, squared, and then
    shifted down by `self._pedestal`; all ops are created under
    `self.name_scope`.
    """
    with self.name_scope:
        bounded = math_ops.lower_bound(self.variable, self._bound)
        squared = tf.math.square(bounded)
        return squared - self._pedestal
def call(self, inputs, training):
    """Run a tensor through the bottleneck and score the result.

    Args:
      inputs: The tensor to be passed through the bottleneck.
      training: Boolean. When truthy, uniform noise in [-0.5, 0.5) is added,
        giving a differentiable approximation of the inputs and their
        likelihoods under the modeled probability densities. When falsy, the
        inputs are rounded (around the medians if
        `self.optimize_integer_offset` is set) and the likelihoods are those
        of the corresponding probability mass function. The latter can't be
        used for training, as they are not differentiable, but represent
        actual compression more closely.

    Returns:
      values: `Tensor` with the same shape as `inputs` containing the
        perturbed or quantized input values.
      likelihood: `Tensor` with the same shape as `inputs` containing the
        likelihood of `values` under the modeled probability distributions.

    Raises:
      ValueError: if `inputs` has different `dtype` or number of channels than
        a previous set of inputs the model was invoked with earlier.
    """
    inputs = ops.convert_to_tensor(inputs)
    ndim = self.input_spec.ndim
    channel_axis = self._channel_axis(ndim)
    half = constant_op.constant(.5, dtype=self.dtype)

    # Move the channel axis to the front, then flatten everything else so the
    # working layout is (channels, 1, batch).
    remaining_axes = list(range(ndim))
    del remaining_axes[channel_axis]
    to_front = [channel_axis] + remaining_axes
    values = array_ops.transpose(inputs, to_front)
    shape = array_ops.shape(values)
    values = array_ops.reshape(values, (shape[0], 1, -1))

    # Perturb with noise or quantize.
    if training:
        noise_shape = array_ops.shape(values)
        noise = random_ops.random_uniform(noise_shape, -half, half)
        values = math_ops.add_n([values, noise])
    elif self.optimize_integer_offset:
        values = math_ops.round(values - self._medians) + self._medians
    else:
        values = math_ops.round(values)

    # Evaluate densities. Differences are taken in the left tail of the
    # sigmoid for numerical stability: sigmoid(x) is 1 for large x and 0 for
    # small x, and subtracting two numbers close to 0 is much more precise
    # than subtracting two numbers close to 1.
    lower = self._logits_cumulative(values - half, stop_gradient=False)
    upper = self._logits_cumulative(values + half, stop_gradient=False)
    # Flip signs if that moves the arguments towards the left tail.
    sign = -math_ops.sign(math_ops.add_n([lower, upper]))
    sign = array_ops.stop_gradient(sign)
    upper_prob = math_ops.sigmoid(sign * upper)
    lower_prob = math_ops.sigmoid(sign * lower)
    likelihood = abs(upper_prob - lower_prob)
    if self.likelihood_bound > 0:
        floor = constant_op.constant(self.likelihood_bound, dtype=self.dtype)
        likelihood = tfc_math_ops.lower_bound(likelihood, floor)

    # Undo the layout change for both results.
    trailing_axes = list(range(1, ndim))
    to_original = (
        trailing_axes[:channel_axis] + [0] + trailing_axes[channel_axis:])
    values = array_ops.reshape(values, shape)
    values = array_ops.transpose(values, to_original)
    likelihood = array_ops.reshape(likelihood, shape)
    likelihood = array_ops.transpose(likelihood, to_original)

    if context.executing_eagerly():
        return values, likelihood

    # Outside eager execution, attach the static shapes reported by
    # `compute_output_shape` to both results.
    values_shape, likelihood_shape = self.compute_output_shape(inputs.shape)
    values.set_shape(values_shape)
    likelihood.set_shape(likelihood_shape)
    return values, likelihood
def reparam(var):
    """Lower-bound `var`, square it, and subtract the pedestal.

    `bound` and `pedestal` are not parameters; they are read from the
    enclosing scope.
    """
    bounded = math_ops.lower_bound(var, bound)
    return tf.math.square(bounded) - pedestal
def build(self, input_shape):
    """Create the quantized CDF tables and the per-element table indexes.

    The quantized CDF table is precomputed from the scale table and stored in
    non-trainable variables. Unless `_indexes` has already been set on the
    instance, the graph that maps every entry of the scale tensor to a row of
    that table is created as well. Finally the per-row offsets are stored.

    Arguments:
      input_shape: Shape of the input tensor.

    Raises:
      ValueError: If `input_shape` doesn't specify number of input dimensions.
    """
    input_shape = tf.TensorShape(input_shape)
    input_shape.assert_is_compatible_with(self.input_spec.shape)

    scale_table = tf.constant(self.scale_table, dtype=self.dtype)

    # Bounding the scales has to happen here, not in __init__, because the
    # dtype may not yet be known there.
    if self.scale_bound is None:
        self._scale = math_ops.lower_bound(self._scale, scale_table[0])
    elif self.scale_bound > 0:
        self._scale = math_ops.lower_bound(self._scale, self.scale_bound)

    quantile_factor = -self._standardized_quantile(self.tail_mass / 2)
    half_widths = np.ceil(
        np.array(self.scale_table) * quantile_factor).astype(int)
    row_lengths = 2 * half_widths + 1
    longest = np.max(row_lengths)
    num_rows = len(row_lengths)

    # Assumes the standardized cumulative satisfies 1 - c(x) = c(-x), which
    # lets every difference be computed in the left tail. c(x) is 1 for large
    # x and 0 for small x, and subtracting two numbers close to 0 is much more
    # precise than subtracting two numbers close to 1.
    grid = np.abs(np.arange(longest, dtype=int) - half_widths[:, None])
    grid = tf.constant(grid, dtype=self.dtype)
    scales_column = tf.expand_dims(scale_table, 1)
    upper = self._standardized_cumulative((.5 - grid) / scales_column)
    lower = self._standardized_cumulative((-.5 - grid) / scales_column)
    pmf = upper - lower

    # Out-of-range (tail) masses, taken from the first column.
    tails = 2 * lower[:, :1]

    def initialize_cdf(shape, dtype=None, partition_info=None):
        del partition_info  # unused
        assert tuple(shape) == (num_rows, longest + 2)
        assert dtype == tf.int32
        lengths_tensor = tf.constant(row_lengths, dtype=tf.int32)
        return self._pmf_to_cdf(pmf, tails, lengths_tensor, longest)

    quantized_cdf = self.add_variable(
        "quantized_cdf",
        shape=(num_rows, longest + 2),
        initializer=initialize_cdf,
        dtype=tf.int32,
        trainable=False)
    cdf_length = self.add_variable(
        "cdf_length",
        shape=(num_rows,),
        initializer=tf.initializers.constant(row_lengths + 2),
        dtype=tf.int32,
        trainable=False)
    # Works around a weird TF issue with reading variables inside a loop.
    self._quantized_cdf = tf.identity(quantized_cdf)
    self._cdf_length = tf.identity(cdf_length)

    # Only compute the table indexes if they haven't been overridden.
    if not hasattr(self, "_indexes"):
        # Prevent tensors from bouncing back and forth between host and GPU.
        with tf.device("/cpu:0"):
            top_index = tf.constant(len(self.scale_table) - 1, dtype=tf.int32)
            start = tf.fill(tf.shape(self.scale), top_index)

            def step_down(indexes, table_scale):
                # Move one row down wherever the element's scale does not
                # exceed this table entry.
                at_or_below = tf.cast(self.scale <= table_scale, tf.int32)
                return indexes - at_or_below

            self._indexes = tf.foldr(
                step_down,
                scale_table[:-1],
                initializer=start,
                back_prop=False,
                name="compute_indexes")

    self._offset = tf.constant(-half_widths, dtype=tf.int32)

    super(SymmetricConditional, self).build(input_shape)
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format.

    For every batch, the latents `y` and `z` are initialised from the analysis
    transforms and then refined directly with a NumPy-side Adam optimizer on
    the rate(-distortion) loss, using straight-through rounding in the graph.
    The refined latents are rounded, evaluated, and the per-image metrics are
    written to an `.npz` file in `args.results_dir` and averaged on stdout.

    Args:
      args: Namespace-like object. Attributes read here: `input_file`,
        `num_filters`, `lmbda` (overwritten from `runname` when negative),
        `runname`, `checkpoint_dir`, `results_dir`.

    NOTE(review): `save_opt_record`, `stop_early`, `SCALES_MIN`, `SCALES_MAX`,
    `SCALES_LEVELS` and the transform classes are not defined in this
    function; presumably they are module-level names -- confirm. The source
    was whitespace-mangled, so the nesting below was reconstructed from the
    logic and should be checked against the original file.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])  # not used further down in this function
    # Product of all dims except the first (batch) and last (channels).
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    # The graph is built on a placeholder instead of the iterator output, so
    # one fetched batch can be fed to many sess.run calls.
    x_ph = x = tf.placeholder(
        'float32', (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, num_output_filters=2 * args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    # `y` is fed from NumPy on every optimization step.
    y = tf.placeholder('float32', y_init.shape)
    from utils import round_with_identity_STE as round_with_STE
    y_tilde = round_with_STE(y)

    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior
    # # p(z_tilde) ("z_likelihoods")
    # z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)
    z = tf.placeholder('float32', z_init.shape)
    z_tilde = round_with_STE(z)
    _ = entropy_bottleneck(
        z, training=False
    )  # dummy call to ensure entropy_bottleneck is properly built
    z_likelihoods = entropy_bottleneck._likelihood(z_tilde)  # p(\tilde z)
    if entropy_bottleneck.likelihood_bound > 0:
        likelihood_bound = entropy_bottleneck.likelihood_bound
        z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_bound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # graph = build_graph(args, x, training=False)

    # Total number of bits divided by number of pixels.
    # - log p(\tilde y | \tilde z) - log p(\tilde z) - - log q(\tilde z | \tilde y)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # The value is parsed out of a 'lmbda=<value>-' fragment of the run name.
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        # lmbda == 0: optimize the rate only.
        rd_loss = train_bpp
    # Gradients w.r.t. the fed latents; consumed by the NumPy Adam optimizer.
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    # `x *= 255` rebinds `x` to a new tensor; `x_ph` still refers to the placeholder.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # `eval_fields` and `eval_tensors` are zipped later and must stay aligned.
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        # Log every 100 steps by default, every 10 when recording or stopping early.
        log_itv = 100
        if save_opt_record or stop_early:
            log_itv = 10
        rd_lr = 0.0001
        rd_opt_its = 2000
        from adam import Adam

        batch_idx = 0
        while True:
            try:
                # Fetch one batch; raises OutOfRangeError once the dataset is exhausted.
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init],
                                        feed_dict=x_feed_dict)  # np arrays
                # A fresh optimizer is created for every batch.
                adam_optimizer = Adam(lr=rd_lr)
                if stop_early:
                    obj_prev = np.inf
                    y_prev, z_prev = None, None
                # Re-created for every batch; the copy saved after the loop is
                # the one from the last batch.
                opt_record = {
                    'its': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        print(
                            'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                            % (it, obj, mse_, train_bpp_, psnr_))
                        # NOTE(review): early stopping is only evaluated on
                        # logging iterations -- confirm against the original
                        # indentation.
                        if stop_early:
                            if obj >= obj_prev:  # no longer improving
                                # Roll back to the latents kept at the previous check.
                                y_cur, z_cur = y_prev, z_prev
                                break
                            else:
                                obj_prev = obj
                                y_prev, z_prev = y_cur, z_cur
                        opt_record['its'].append(it)
                        opt_record['rd_loss'].append(obj)
                        # Note: the same `obj` value is stored under both loss keys.
                        opt_record['rd_loss_after_rounding'].append(obj)
                print()

                y_tilde_cur = np.round(
                    y_cur)  # this is the latents we end up transmitting
                z_tilde_cur = np.round(z_cur)

                # If requested, transform the quantized image back and measure performance.
                # The rounded latents are fed straight into the `y_tilde` and
                # `z_tilde` tensors, not into the `y` and `z` placeholders.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_tilde: z_tilde_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()
                batch_idx += 1

            except tf.errors.OutOfRangeError:
                # No more batches.
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        # Text before the first '-' in the run name is compared with this script's name.
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            # Include this script's name and lmbda when they differ from the run name's.
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
def test_lower_bound_invalid(self):
    """`lower_bound` must raise `ValueError` for an unknown `gradient` value."""
    zeros = tf.zeros((1, 2))
    with self.assertRaises(ValueError):
        math_ops.lower_bound(zeros, 0, gradient="invalid")
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format.

    For every batch, the mean and log-variance of a Gaussian q(z_tilde) are
    initialised from the hyper analysis transform and then refined with a
    NumPy-side Adam optimizer on the estimated bits-per-pixel, which includes
    a bits-back term subtracted from the rate. The per-image metrics are
    written to an `.npz` file in `args.results_dir` and averaged on stdout.

    Args:
      args: Namespace-like object. Attributes read here: `input_file`,
        `num_filters`, `runname`, `checkpoint_dir`, `results_dir`.

    NOTE(review): `likelihood_lowerbound`, `SCALES_MIN`, `SCALES_MAX`,
    `SCALES_LEVELS` and the transform classes are not defined in this
    function; presumably they are module-level names -- confirm. The source
    was whitespace-mangled, so the nesting below was reconstructed from the
    logic and should be checked against the original file.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])  # not used further down in this function
    # Product of all dims except the first (batch) and last (channels).
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    # The graph is built on a placeholder instead of the iterator output, so
    # one fetched batch can be fed to many sess.run calls.
    x_ph = x = tf.placeholder(
        'float32', (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    # The hyper analysis output is split in two halves (mean, log-variance) below.
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters, num_output_filters=2 * args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, num_output_filters=2 * args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)
    y_tilde = tf.round(y)

    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2,
                                          axis=-1)
    # `z_mean` and `z_logvar` are fed from NumPy on every optimization step.
    z_mean = tf.placeholder(
        'float32', z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)

    # Reparameterized sample: z_tilde = z_mean + exp(z_logvar / 2) * eps.
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels.
    # - log p(\tilde y | \tilde z) - log p(\tilde z) - - log q(\tilde z | \tilde y)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    bpp_back = tf.reduce_sum(
        -log_q_z_tilde, axis=axes_except_batch) / (np.log(2) * img_num_pixels)
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    # The bits-back term is subtracted from the rate.
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Gradients w.r.t. the fed posterior parameters; consumed by the NumPy Adam optimizer.
    local_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range.
    # `x *= 255` rebinds `x` to a new tensor; `x_ph` still refers to the placeholder.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # `eval_fields` and `eval_tensors` are zipped later and must stay aligned.
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        batch_idx = 0
        while True:
            try:
                # Fetch one batch; raises OutOfRangeError once the dataset is exhausted.
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}

                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init],
                    feed_dict=x_feed_dict)  # np arrays

                # Per-batch histories; they are not read again in this function.
                opt_obj_hist = []
                opt_grad_hist = []
                lr = 0.005
                local_its = 1000
                from adam import Adam
                # A fresh optimizer is created for every batch.
                np_adam_optimizer = Adam(lr=lr)
                for it in range(local_its):
                    grads, obj = sess.run([local_gradients, train_bpp],
                                          feed_dict={
                                              z_mean: z_mean_cur,
                                              z_logvar: z_logvar_cur,
                                              **x_feed_dict
                                          })
                    z_mean_cur, z_logvar_cur = np_adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % 100 == 0:
                        print('negative local ELBO', obj)
                    # NOTE(review): the histories are assumed to be appended on
                    # every iteration, not only on logging iterations -- confirm.
                    opt_obj_hist.append(obj)
                    opt_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()
                batch_idx += 1

            except tf.errors.OutOfRangeError:
                # No more batches.
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        # Text before the first '-' in the run name is compared with this script's name.
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension
        save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
        if script_name != trained_script_name:
            # Include this script's name when it differs from the run name's.
            save_file = 'rd-%s+%s-input=%s.npz' % (script_name, args.runname,
                                                   input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
def compress(args):
    """Compress an image, or a batch of same-shape images stored in npy format.

    The input is loaded from ``args.input_file`` and scaled to [0, 1]. After
    restoring the latest checkpoint found under
    ``args.checkpoint_dir/args.runname``, the latents ``y`` and ``z`` of every
    batch are refined directly with Adam on the rate-distortion loss. During
    this refinement each latent is replaced by a softmax-weighted mix of its
    floor and ceil, with a temperature ``T`` that is annealed over the
    iterations. The refined latents are then rounded with ``np.round`` and
    evaluated; per-image metrics are written to an ``.npz`` file in
    ``args.results_dir`` and their averages are printed.

    Arguments:
      args: Namespace read for ``input_file``, ``num_filters``, ``lmbda``,
        ``runname``, ``checkpoint_dir``, ``results_dir`` and ``verbose``.
        ``args.lmbda`` is overwritten in place when it is negative.
    """
    from configs import get_eval_batch_size
    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]
    # NOTE(review): num_images is not used anywhere below in this function.
    num_images = int(X.shape[0])
    # Product of all dims except the first (batch) and last (channel) one.
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.
    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    # The graph is built on a placeholder (not on x_next) so the same batch can
    # be fed to many sess.run calls without advancing the iterator.
    x_ph = x = tf.placeholder(
        'float32', (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 * args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    # y is a placeholder: its value is held in numpy and updated outside the graph.
    y = tf.placeholder('float32', y_init.shape)
    T = tf.placeholder('float32', shape=[], name='temperature')
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    epsilon = 1e-5
    ry_logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim are logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    ry = tf.nn.softmax(ry_logits, axis=-1)
    y_tilde = tf.reduce_sum(y_bds * ry, axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior
    # # p(z_tilde) ("z_likelihoods")
    # z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)
    # Same soft rounding as for y above, applied to the hyper-latent z.
    z = tf.placeholder('float32', z_init.shape)
    z_floor = tf.floor(z)
    z_ceil = tf.ceil(z)
    z_bds = tf.stack([z_floor, z_ceil], axis=-1)
    rz_logits = tf.stack([
        -tf.math.atanh(tf.clip_by_value(z - z_floor, -1 + epsilon, 1 - epsilon)) / T,
        -tf.math.atanh(tf.clip_by_value(z_ceil - z, -1 + epsilon, 1 - epsilon)) / T
    ], axis=-1)  # last dim are logits for DOWN or UP
    rz = tf.nn.softmax(rz_logits, axis=-1)
    z_tilde = tf.reduce_sum(z_bds * rz, axis=-1)  # inner product in last dim

    # # We have to manually call entropy_bottleneck.build because we don't directly call entropy_bottleneck like we did
    # # with 'z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)' during training
    # # UPDATE: this doesn't quite work, as the resulting variables don't have the proper name scope (will just be named
    # # "matrix_0", "bias_0", etc., instead of "entropy_bottleneck/matrix_0", "entropy_bottleneck/bias_0" as would with
    # # calling entropy_bottleneck on tensor, which breaks model loading (will get "Key bias_0 not found in checkpoint..
    # # tensorflow.python.framework.errors_impl.NotFoundError: Restoring from checkpoint failed. This is most likely due
    # # to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not
    # # altered the graph expected based on the checkpoint.").
    # entropy_bottleneck.build(z_tilde.shape)
    _ = entropy_bottleneck(
        z, training=False
    )  # dummy call to ensure entropy_bottleneck is properly built
    z_likelihoods = entropy_bottleneck._likelihood(z_tilde)  # p(\tilde z)
    if entropy_bottleneck.likelihood_bound > 0:
        likelihood_bound = entropy_bottleneck.likelihood_bound
        z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_bound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels.
    # - log p(\tilde y | \tilde z) - log p(\tilde z) - - log q(\tilde z | \tilde y)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # A negative lmbda means "parse the value out of the run name".
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    # Gradients w.r.t. the placeholders; the Adam updates are applied in numpy.
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)
    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # eval_fields and eval_tensors are parallel lists; keep them in the same order.
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        log_itv = 100
        # NOTE(review): save_opt_record is not assigned in this function;
        # presumably a module-level flag -- confirm.
        if save_opt_record:
            log_itv = 10
        rd_lr = 0.005
        rd_opt_its = 2000
        annealing_rate = 4e-3
        T_ub = 0.2

        def annealed_temperature(t, r, ub, lb=1e-8, backend=np):
            # Using the exp schedule from section 4.2 of Jang et. al., ICLR2017
            # Returns exp(-r * t) clamped to the interval [lb, ub].
            if backend is None:
                return min(max(np.exp(-r * t), lb), ub)
            else:
                return backend.minimum(
                    backend.maximum(backend.exp(-r * t), lb), ub)

        from adam import Adam

        batch_idx = 0
        # Loop until the one-shot iterator raises OutOfRangeError.
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init],
                                        feed_dict=x_feed_dict)  # np arrays
                adam_optimizer = Adam(lr=rd_lr)
                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict, T: temperature
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            # Feeding y_tilde / z_tilde directly overrides the
                            # soft-rounded tensors with hard-rounded values.
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_tilde: np.round(z_cur),
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_, rd_loss_after_rounding,
                                   bpp_after_rounding, psnr_after_rounding))
                            opt_record['rd_loss_after_rounding'].append(
                                rd_loss_after_rounding)
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_))
                        opt_record['its'].append(it)
                        opt_record['T'].append(temperature)
                        opt_record['rd_loss'].append(obj)

                print()

                y_tilde_cur = np.round(
                    y_cur)  # this is the latents we end up transmitting
                z_tilde_cur = np.round(z_cur)

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_tilde: z_tilde_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()
                batch_idx += 1
            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            # Evaluated by a different script than the one named in runname:
            # include this script's name and lmbda in the file name.
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record
            # NOTE(review): opt_record is re-created for every batch, so only
            # the last batch's record is written here.
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.
    During training we sample from box-shaped posteriors; during compression this is approximated by rounding.

    Arguments:
      args: Namespace; only ``args.num_filters`` is read here.
      x: Input image tensor.
      training: Boolean forwarded to the conditional bottleneck. When `True`,
        sigma is additionally upper-bounded; when `False`, mu/sigma and the
        reconstruction are cropped to the shapes of y and x respectively.

    Returns:
      The function's ``locals()`` dict, i.e. every local name defined below
      (including the function-scope imports) mapped to its value. Renaming or
      adding a local therefore changes the returned keys.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    # The hyper analysis output is split in two below (z_mean, z_logvar),
    # hence twice the number of output filters.
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters, num_output_filters=2 * args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, num_output_filters=2 * args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # z_tilde ~ q(z_tilde | x) = q(z_tilde | h_a(y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y), num_or_size_splits=2, axis=-1)
    # Reparameterized Gaussian sample: mean + eps * exp(logvar / 2).
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    # NOTE(review): likelihood_lowerbound (and variance_upperbound below) are
    # not assigned in this function; presumably module-level constants.
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde), num_or_size_splits=2, axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        sigma = math_ops.upper_bound(sigma, variance_upperbound**0.5)
    if not training:
        # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mu)
    # sample y_tilde from q(y_tilde|x) = U(y-0.5, y+0.5) = U(g_a(x)-0.5, g_a(x)+0.5), and then compute the pdf of
    # y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde) = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    # Note that at test/compression time, the resulting y_tilde doesn't simply
    # equal round(y); instead, the conditional_bottleneck does something
    # smarter and slightly more optimal: y_hat=floor(y + 0.5 - prior_mean), so
    # that the mean (mu) of the prior coincides with one of the quantization bins.
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=training)
    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[
            2], :]  # crop reconstruction to have the same shape as input
    return locals()
def compress(args):
    """Compress an image, or a batch of same-shape images stored in npy format.

    The input is loaded from ``args.input_file`` and scaled to [0, 1]. After
    restoring the latest checkpoint found under
    ``args.checkpoint_dir/args.runname``, every batch goes through two
    numpy-side Adam optimizations:

    1. Rate-distortion optimization of ``y``, ``z_mean`` and ``z_logvar``,
       where ``y`` is stochastically rounded via a relaxed one-hot sample at
       an annealed temperature.
    2. With ``y`` fixed to ``np.round`` of the result, ``z_mean`` and
       ``z_logvar`` are re-initialized from the hyper analysis transform and
       optimized for the rate term alone.

    Per-image metrics are written to an ``.npz`` file in ``args.results_dir``
    and their averages are printed.

    Arguments:
      args: Namespace read for ``input_file``, ``num_filters``, ``lmbda``,
        ``runname``, ``checkpoint_dir``, ``results_dir``, ``verbose``,
        ``annealing_rate`` and ``t0``. ``args.lmbda`` is overwritten in place
        when it is negative.
    """
    from configs import get_eval_batch_size
    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]
    # NOTE(review): num_images is not used anywhere below in this function.
    num_images = int(X.shape[0])
    # Product of all dims except the first (batch) and last (channel) one.
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.
    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    # The graph is built on a placeholder (not on x_next) so the same batch can
    # be fed to many sess.run calls without advancing the iterator.
    x_ph = x = tf.placeholder(
        'float32', (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters, num_output_filters=2 * args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, num_output_filters=2 * args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial optimization (where we still have access to x)
    # Soft-to-hard rounding with Gumbel-softmax trick; for each element of z_tilde, let R be a 2D auxiliary one-hot
    # random vector, such that R=[1, 0] means rounding DOWN and [0, 1] means rounding UP.
    # Let the logits of each outcome be -(z - z_floor) / T and -(z_ceil - z) / T (i.e., Boltzmann distribution with
    # energies (z - floor(z)) and (ceil(z) - z), so p(R==[1,0]) = softmax((z - z_floor) / T), ...
    # Let z_tilde = p(R==[1,0]) * floor(z) + p(R==[0,1]) * ceil(z), so z_tilde -> round(z) as T -> 0.
    # NOTE(review): the comment above speaks of z; in this function the
    # relaxed rounding is applied to y (see y_tilde below).
    import tensorflow_probability as tfp
    T = tf.placeholder('float32', shape=[], name='temperature')
    y_init = analysis_transform(x)
    # y is a placeholder: its value is held in numpy and updated outside the graph.
    y = tf.placeholder('float32', y_init.shape)
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    epsilon = 1e-5
    logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim are logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    rounding_dist = tfp.distributions.RelaxedOneHotCategorical(
        T, logits=logits)  # technically we can use a different temperature here
    sample_concrete = rounding_dist.sample()
    y_tilde = tf.reduce_sum(y_bds * sample_concrete,
                            axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2,
                                          axis=-1)
    z_mean = tf.placeholder(
        'float32', z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)

    # Reparameterized Gaussian sample: mean + eps * exp(logvar / 2).
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    # NOTE(review): likelihood_lowerbound is not assigned in this function;
    # presumably a module-level constant.
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels.
    # - log p(\tilde y | \tilde z) - log p(\tilde z) - - log q(\tilde z | \tilde y)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    batch_log_q_z_tilde = tf.reduce_sum(log_q_z_tilde, axis=axes_except_batch)
    bpp_back = -batch_log_q_z_tilde / (np.log(2) * img_num_pixels)
    batch_log_cond_p_y_tilde = tf.reduce_sum(tf.log(y_likelihoods),
                                             axis=axes_except_batch)
    y_bpp = -batch_log_cond_p_y_tilde / (np.log(2) * img_num_pixels)
    batch_log_p_z_tilde = tf.reduce_sum(tf.log(z_likelihoods),
                                        axis=axes_except_batch)
    z_bpp = -batch_log_p_z_tilde / (np.log(2) * img_num_pixels)
    # The bits-back term is subtracted from the total rate.
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # A negative lmbda means "parse the value out of the run name".
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    # Gradients w.r.t. the placeholders; the Adam updates are applied in numpy.
    rd_gradients = tf.gradients(rd_loss, [y, z_mean, z_logvar])
    r_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)
    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # eval_fields and eval_tensors are parallel lists; keep them in the same order.
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        # NOTE(review): plt is only referenced by the commented-out plotting
        # code further down.
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        log_itv = 100
        rd_lr = 0.005
        # rd_opt_its = args.sga_its
        rd_opt_its = 2000
        annealing_scheme = 'exp0'
        annealing_rate = args.annealing_rate  # default annealing_rate = 1e-3
        t0 = args.t0  # default t0 = 700
        T_ub = 0.5  # max/initial temperature
        from utils import annealed_temperature

        r_lr = 0.003
        r_opt_its = 2000

        from adam import Adam

        batch_idx = 0
        # Loop until the one-shot iterator raises OutOfRangeError.
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur = sess.run(y_init, feed_dict=x_feed_dict)  # np arrays
                # Feeding y_tilde directly bypasses the relaxed rounding, so no
                # temperature is needed for this initialization.
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init], feed_dict={y_tilde: y_cur})
                rd_loss_hist = []
                adam_optimizer = Adam(lr=rd_lr)
                # NOTE(review): opt_record is created but never appended to or
                # saved in this function.
                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub,
                                                       scheme=annealing_scheme,
                                                       t0=t0)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z_mean: z_mean_cur,
                            z_logvar: z_logvar_cur,
                            **x_feed_dict, T: temperature
                        })
                    y_cur, z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [y_cur, z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            # Feeding y_tilde directly overrides the sampled
                            # soft rounding with hard-rounded values.
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_mean: z_mean_cur,
                                    z_logvar: z_logvar_cur,
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_, rd_loss_after_rounding,
                                   bpp_after_rounding, psnr_after_rounding))
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_))
                    rd_loss_hist.append(obj)
                print()

                # 2. Fix y_tilde, perform rate optimization w.r.t. z_mean and z_logvar.
                y_tilde_cur = np.round(
                    y_cur)  # this is the latents we end up transmitting
                # rate_feed_dict = {y_tilde: y_tilde_cur, **x_feed_dict}
                rate_feed_dict = {y_tilde: y_tilde_cur}
                # NOTE(review): seed is not assigned in this function;
                # presumably a module-level constant -- confirm.
                np.random.seed(seed)
                tf.set_random_seed(seed)
                print('----Rate Optimization----')
                # Reinitialize based on the value of y_tilde
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init],
                    feed_dict=rate_feed_dict)  # np arrays

                r_loss_hist = []
                # rate_grad_hist = []

                adam_optimizer = Adam(lr=r_lr)
                for it in range(r_opt_its):
                    grads, obj = sess.run(
                        [r_gradients, train_bpp],
                        feed_dict={
                            z_mean: z_mean_cur,
                            z_logvar: z_logvar_cur,
                            **rate_feed_dict
                        })
                    z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == r_opt_its:
                        print('it=', it, '\trate=', obj)
                    r_loss_hist.append(obj)
                    # rate_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # fig, axes = plt.subplots(nrows=2, sharex=True)
                # axes[0].plot(rd_loss_hist)
                # axes[0].set_ylabel('RD loss')
                # axes[1].plot(r_loss_hist)
                # axes[1].set_ylabel('Rate loss')
                # axes[1].set_xlabel('SGD iterations')
                # plt.savefig('plots/local_q_opt_hist-%s-input=%s-b=%d.png' %
                #             (args.runname, os.path.basename(args.input_file), batch_idx))

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()
                batch_idx += 1
            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension
        save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
        if script_name != trained_script_name:
            # Evaluated by a different script than the one named in runname:
            # include this script's name and lmbda in the file name.
            save_file = 'rd-%s-lmbda=%g+%s-input=%s.npz' % (
                script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
def train(args):
    """Fit a ``BMSHJ2018Prior`` density model to data by maximum likelihood.

    The data array is loaded from ``args.data_path`` and fed in full at every
    step. The loss is the negative mean log of the model pdf (lower-bounded at
    1e-10) and is minimized with Adam. Training runs for at most ``args.its``
    iterations and stops early once the relative change in loss between two
    consecutive iterations drops below ``args.tol``.

    Files written under ``args.checkpoint_dir/<runname>``:
      * ``args.json``: the arguments used for this run (overwritten).
      * ``prior_model``: the model weights, via ``model.save_weights``.
      * ``record.json``: the logged ``(it, loss)`` pairs (overwritten).
      * ``<runname>_it=<it>.png``: per-channel density plots, when
        ``args.plot`` is set.

    Arguments:
      args: Namespace read for ``num_channels``, ``dims``, ``init_scale``,
        ``data_path``, ``lr``, ``its``, ``tol``, ``logging_freq``, ``plot``
        and ``checkpoint_dir``.
    """
    import json
    import os

    tf.reset_default_graph()
    num_channels = args.num_channels
    dims = args.dims
    init_scale = args.init_scale
    model = BMSHJ2018Prior(num_channels, dims=dims, init_scale=init_scale)
    runname = get_runname(vars(args))
    data = np.load(args.data_path)
    lr = args.lr
    its = args.its
    tol = args.tol
    logging_freq = args.logging_freq
    plot = args.plot
    if plot:
        import matplotlib.pyplot as plt

    save_dir = os.path.join(args.checkpoint_dir, runname)
    os.makedirs(save_dir, exist_ok=True)
    model_name = os.path.join(save_dir, 'prior_model')
    with open(os.path.join(save_dir, 'args.json'), 'w') as f:  # will overwrite existing
        json.dump(vars(args), f, indent=4, sort_keys=True)

    assert not tf.executing_eagerly()
    X = tf.placeholder(data.dtype, [None, num_channels])
    pdf_lower_bound = 1e-10
    pdf = model.pdf(X, stop_gradient=False)
    # Bound the pdf away from zero so the log below stays finite.
    pdf = math_ops.lower_bound(pdf, pdf_lower_bound)
    log_likelihood = tf.math.log(pdf)
    print(log_likelihood)
    loss = -tf.reduce_mean(log_likelihood)
    optimizer = tf.train.AdamOptimizer(lr)
    train_step = optimizer.minimize(loss)

    record = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        prev_loss = float('inf')
        for it in range(its):
            sess.run(train_step, feed_dict={X: data})
            loss_ = float(sess.run(loss, feed_dict={X: data}))

            # Relative-change convergence test, written as a product so a loss
            # of exactly 0 cannot raise ZeroDivisionError. On the first
            # iteration prev_loss is inf, so the test is always False.
            if abs(prev_loss - loss_) < tol * abs(loss_):
                break
            # Without this update the test above would compare against inf
            # forever and early stopping could never trigger.
            prev_loss = loss_

            if it % logging_freq == 0 or it + 1 == its:
                print('it=%d,\t\tloss=%g' % (it, loss_))
                record.append(dict(it=it, loss=loss_))

                if plot:
                    # Plot the model density against a histogram of the data
                    # for the first few channels.
                    xlim = [-5, 5]
                    xs = np.linspace(*xlim)
                    figsize = (12, 8)
                    plt.figure(figsize=figsize)
                    xs_feed = np.tile(xs[..., None], num_channels)  # len(xs) by num_channels
                    q_xs = sess.run(pdf, feed_dict={X: xs_feed})
                    h, v = 2, 4
                    # Never index past the available channels.
                    for k in range(min(h * v, num_channels)):
                        plt.subplot(h, v, k + 1)
                        plt.plot(xs, q_xs[:, k], label='$q(x)$')
                        bins = 31
                        # Raw string: '\h' is not a valid escape sequence.
                        plt.hist(data[:, k].ravel(), bins=bins, density=True,
                                 alpha=0.4, label=r'$\hat q(z)$')
                        plt.xlim(xlim)
                        plt.title('channel %d, it %d' % (k, it))
                        plt.legend()
                    plt.tight_layout()
                    plt.savefig(os.path.join(save_dir, runname + '_it=%d.png' % it))
                    # Release the figure so figures do not pile up across
                    # logging steps.
                    plt.close()

        # Saved inside the session so the trained variable values are used.
        model.save_weights(model_name)

    with open(os.path.join(save_dir, 'record.json'), 'w') as f:  # will overwrite existing
        json.dump(record, f, indent=4, sort_keys=True)
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format.

    For every batch, the latents `y` and hyper-latents `z` are initialized from
    the analysis transforms and then refined by gradient steps on the
    rate-distortion loss, using a temperature-annealed relaxed (Gumbel-softmax
    style) rounding of the latents. The refined latents are rounded with
    `np.round` and the rounded values are used to evaluate MSE, PSNR, MS-SSIM
    and the estimated bits per pixel.

    Arguments:
      args: Namespace-like object. Attributes read here: `input_file`,
        `num_filters`, `lmbda`, `runname`, `checkpoint_dir`, `annealing_rate`,
        `t0`, `verbose` and `results_dir`. If `args.lmbda` is negative it is
        overwritten in place with the value parsed from `args.runname`.

    Side effects:
      Restores the latest checkpoint found in
      `<args.checkpoint_dir>/<args.runname>`, prints progress, and writes an
      `rd-*.npz` results file to `args.results_dir`. If `save_opt_record` is
      set, an `opt-*.npz` file with the optimization record of the last
      processed batch is also written; if `save_reconstruction` is set, a
      `recon-*.png` file is written (single-image input only).

    NOTE(review): `epsilon`, `save_reconstruction`, `save_opt_record`,
    `SCALES_MIN`, `SCALES_MAX`, `SCALES_LEVELS`, `write_png` and `os` are not
    defined inside this function; presumably they are module-level names.
    """
    from configs import get_eval_batch_size
    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]
    num_images = int(X.shape[0])
    # Product of all dims except the first (batch) and the last (channels).
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.
    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # the iterator output. To evaluate several ops on the same batch, the batch is therefore fetched once
    # (x_next) and then fed back through a placeholder (x_ph) instead of wiring the iterator into the graph.
    x_next = dataset.make_one_shot_iterator().get_next()
    x_shape = (None, *X.shape[1:])
    x_ph = x = tf.placeholder('float32', x_shape)  # keep a reference around for feed_dict (x is rebound below)

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters, num_output_filters=2 * args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    # Soft-to-hard rounding with Gumbel-softmax trick; for each element of z, let R be a 2D auxiliary one-hot
    # random vector, such that R=[1, 0] means rounding DOWN and [0, 1] means rounding UP.
    # The logits of the two outcomes are -atanh(z - floor(z)) / T and -atanh(ceil(z) - z) / T, i.e. the
    # closer z is to a bound, the more likely that bound is chosen.
    # z_tilde = R[0] * floor(z) + R[1] * ceil(z), where R is a relaxed one-hot sample at temperature T.
    import tensorflow_probability as tfp
    T = tf.placeholder('float32', shape=[], name='temperature')
    z = tf.placeholder(
        'float32', z_init.shape
    )  # interface ("proxy") variable for SGA (to be annealed to int)
    z_floor = tf.floor(z)
    z_ceil = tf.ceil(z)
    z_bds = tf.stack([z_floor, z_ceil], axis=-1)
    rz_logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(z - z_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(z_ceil - z, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim are logits for DOWN or UP; clipping keeps the atanh argument strictly inside (-1, 1)
    rz_dist = tfp.distributions.RelaxedOneHotCategorical(
        T, logits=rz_logits
    )  # technically we can use a different temperature here
    rz_sample = rz_dist.sample()
    z_tilde = tf.reduce_sum(z_bds * rz_sample, axis=-1)  # inner product in last dim

    _ = entropy_bottleneck(
        z, training=False
    )  # dummy call to ensure entropy_bottleneck is properly built
    z_likelihoods = entropy_bottleneck._likelihood(z_tilde)  # p(\tilde z)
    if entropy_bottleneck.likelihood_bound > 0:
        likelihood_bound = entropy_bottleneck.likelihood_bound
        z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_bound)

    # compute parameters of conditional prior p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde), num_or_size_splits=2, axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # set up SGA for low-level latents (same construction as for z above)
    y = tf.placeholder(
        'float32', y_init.shape
    )  # interface ("proxy") variable for SGA (to be annealed to int)
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    ry_logits = tf.stack([
        -tf.math.atanh(tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
        -tf.math.atanh(tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
    ], axis=-1)  # last dim are logits for DOWN or UP
    ry_dist = tfp.distributions.RelaxedOneHotCategorical(
        T, logits=ry_logits
    )  # technically we can use a different temperature here
    ry_sample = ry_dist.sample()
    y_tilde = tf.reduce_sum(y_bds * ry_sample, axis=-1)  # inner product in last dim

    x_tilde = synthesis_transform(y_tilde)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels:
    # -log2 p(\tilde y | \tilde z) - log2 p(\tilde z), summed over all non-batch axes.
    axes_except_batch = list(range(1, len(x.shape)))  # [1, 2, 3] for a rank-4 input placeholder
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels (computed on the [0, 1]-scaled images).
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # A negative lmbda means "use the training value": parse the number that follows 'lmbda=' in the
        # run name (up to the next '-'). This mutates args.
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        # lmbda == 0: optimize the rate only.
        rd_loss = train_bpp
    # Gradients w.r.t. the proxy placeholders; the updates themselves are applied outside the graph.
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255  # rebinds x to a new tensor; x_ph still refers to the placeholder
    if save_reconstruction:
        x_tilde_float = x_tilde  # keep the unclipped, unrounded reconstruction for write_png
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde), axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        # eval_fields[i] names the tensor eval_tensors[i]; the two lists must stay aligned.
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: [] for key in eval_fields}  # append across all batches

        # Log (and record) more often when the optimization record is going to be saved.
        log_itv = 100
        if save_opt_record:
            log_itv = 10
        rd_lr = 0.005
        # NOTE: the number of optimization steps is hard-coded (an earlier variant read args.sga_its).
        rd_opt_its = 2000
        annealing_scheme = 'exp0'
        annealing_rate = args.annealing_rate  # default annealing_rate = 1e-3
        t0 = args.t0  # default t0 = 700
        T_ub = 0.5  # max/initial temperature
        from utils import annealed_temperature
        from adam import Adam

        batch_idx = 0
        while True:
            try:
                # Fetch the next batch exactly once; OutOfRangeError signals the end of the dataset.
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}

                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init], feed_dict=x_feed_dict)  # np arrays
                adam_optimizer = Adam(lr=rd_lr)  # fresh optimizer state for every batch
                # Re-created for every batch, so only the last batch's record survives the loop.
                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub,
                                                       scheme=annealing_scheme,
                                                       t0=t0)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict,
                            T: temperature
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            # Also evaluate with hard-rounded latents by feeding y_tilde / z_tilde directly,
                            # which overrides the relaxed samples (so no temperature is fed).
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_tilde: np.round(z_cur),
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_, psnr_,
                                   rd_loss_after_rounding, bpp_after_rounding,
                                   psnr_after_rounding))
                            opt_record['rd_loss_after_rounding'].append(
                                rd_loss_after_rounding)
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_, psnr_))
                        opt_record['its'].append(it)
                        opt_record['T'].append(temperature)
                        opt_record['rd_loss'].append(obj)
                print()

                y_tilde_cur = np.round(
                    y_cur)  # this is the latents we end up transmitting
                z_tilde_cur = np.round(z_cur)

                # 2. Evaluate all metrics with the hard-rounded latents.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_tilde: z_tilde_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()
                batch_idx += 1
            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        # The run name is expected to start with the name of the script that trained the model.
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            # Evaluated by a different script than the one used for training: encode both in the file name.
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record (of the last processed batch)
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        if save_reconstruction:
            # Only a single image is supported: the PNG is produced from the last batch's latents.
            assert num_images == 1
            prefix = 'recon'
            save_file = '%s-%s-input=%s.png' % (prefix, args.runname, input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g-rd_opt_its=%d+%s-input=%s.png' % (
                    prefix, script_name, args.lmbda, rd_opt_its, args.runname,
                    input_file)
            # Write reconstructed image out as a PNG file.
            save_file = os.path.join(args.results_dir, save_file)
            print("Saving image reconstruction to ", save_file)
            save_png_op = write_png(save_file, x_tilde_float[0])
            sess.run(save_png_op, feed_dict={y_tilde: y_tilde_cur})

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))