def decompress(args):
    """Decompresses an image."""
    # Adapted from https://github.com/tensorflow/compression/blob/master/examples/bmshj2018.py
    # Read the shape information and compressed string from the binary file.
    string = tf.placeholder(tf.string, [1])
    side_string = tf.placeholder(tf.string, [1])
    x_shape = tf.placeholder(tf.int32, [2])
    y_shape = tf.placeholder(tf.int32, [2])
    z_shape = tf.placeholder(tf.int32, [2])
    with open(args.input_file, "rb") as f:
        packed = tfc.PackedTensors(f.read())
    tensors = [string, side_string, x_shape, y_shape, z_shape]
    arrays = packed.unpack(tensors)

    # Instantiate model. TODO: automate this with build_graph
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32)

    # Decompress and transform the image back.
    z_shape = tf.concat([z_shape, [args.num_filters]], axis=0)
    z_hat = entropy_bottleneck.decompress(side_string,
                                          z_shape,
                                          channels=args.num_filters)

    mu, sigma = tf.split(hyper_synthesis_transform(z_hat),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    training = False
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        mu = mu[:, :y_shape[0], :y_shape[1], :]
        sigma = sigma[:, :y_shape[0], :y_shape[1], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu,
                                                     dtype=tf.float32)
    y_hat = conditional_bottleneck.decompress(string)
    x_hat = synthesis_transform(y_hat)

    # Remove batch dimension, and crop away any extraneous padding on the bottom
    # or right boundaries.
    x_hat = x_hat[0, :x_shape[0], :x_shape[1], :]

    # Write reconstructed image out as a PNG file.
    op = write_png(args.output_file, x_hat)

    # Load the latest model checkpoint, and perform the above actions.
    with tf.Session() as sess:
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        sess.run(op, feed_dict=dict(zip(tensors, arrays)))
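
# The helpers and constants below are referenced throughout this file but are not shown in
# this excerpt. The write_png helper and the SCALES_* constants are assumed to follow the
# official tensorflow/compression bmshj2018.py example; the values of likelihood_lowerbound
# and variance_upperbound are illustrative assumptions only.
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64
likelihood_lowerbound = 1e-9   # assumed floor on likelihoods to avoid log(0)
variance_upperbound = 2 ** 10  # assumed cap on the prior variance during training


def write_png(filename, image):
    """Saves a float image in [0, 1] to a PNG file (as in the official bmshj2018.py)."""
    image = tf.clip_by_value(image, 0, 1)
    image = tf.round(image * 255)
    image = tf.cast(image, tf.uint8)
    string = tf.image.encode_png(image)
    return tf.io.write_file(filename, string)
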
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.
    During training we sample from box-shaped posteriors; during compression this is approximated by rounding.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)  # y = g_a(x)
    z = hyper_analysis_transform(y)  # z = h_a(y)

    # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior
    # p(z_tilde) ("z_likelihoods")
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # sample y_tilde from q(y_tilde|x) = U(y-0.5, y+0.5) = U(g_a(x)-0.5, g_a(x)+0.5), and then compute the pdf of
    # y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde) = N(y_tilde|mu, sigma^2) conv U(-0.5,  0.5)
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=training)
    x_tilde = synthesis_transform(y_tilde)

    if not training:
        side_string = entropy_bottleneck.compress(z)
        string = conditional_bottleneck.compress(y)
        x_shape = tf.shape(x)
        # crop reconstruction to have the same shape as input
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]

    return locals()
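
# A hypothetical sketch (not part of the original code) showing how the dictionary returned by
# build_graph above can be assembled into the usual rate-distortion training objective. It
# mirrors the bpp and MSE computations used in the compress() functions below; the names
# rd_objective_sketch and lmbda are illustrative.
def rd_objective_sketch(args, x, lmbda):
    graph = build_graph(args, x, training=True)
    # rate: total code length of y and z in bits, per pixel, averaged over the batch
    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[1:-1]), tf.float32)
    batch_size = tf.cast(tf.shape(x)[0], tf.float32)
    train_bpp = (tf.reduce_sum(-tf.log(graph['y_likelihoods'])) +
                 tf.reduce_sum(-tf.log(graph['z_likelihoods']))) / (
                     np.log(2) * num_pixels * batch_size)
    # distortion: MSE rescaled to the 0..255 pixel range, as in compress() below
    train_mse = tf.reduce_mean(tf.squared_difference(x, graph['x_tilde'])) * 255. ** 2
    return lmbda * train_mse + train_bpp
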
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.
    During training we sample from box-shaped posteriors; during compression this is approximated by rounding.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # y_tilde ~ q(y_tilde | y = g_a(x))
    half = tf.constant(.5, dtype=y.dtype)
    if training:
        noise = tf.random.uniform(tf.shape(y), -half, half)
        y_tilde = y + noise
    else:
        # Approximately sample from q(y_tilde|x) by rounding. We can't do the smarter y_hat = floor(y + 0.5 - prior_mean)
        # as in Balle's model (ultimately implemented by conditional_bottleneck._quantize), because we don't have the
        # prior p(y_tilde | z_tilde) yet; in the bits-back scheme z_tilde is sampled given y_tilde, whereas in BMSHJ2018
        # z_tilde is obtained conditioned on x.
        y_tilde = tf.round(y)

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y_tilde),
                                num_or_size_splits=2,
                                axis=-1)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        sigma = math_ops.upper_bound(sigma, variance_upperbound**0.5)
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)

    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        # crop reconstruction to have the same shape as input
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]

    return locals()
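
# Sketch of the log_normal_pdf helper imported from utils above. It is assumed to be the
# standard element-wise diagonal-Gaussian log-density (as in the TensorFlow CVAE tutorial);
# the actual utils implementation may differ.
def log_normal_pdf(sample, mean, logvar):
    """Element-wise log N(sample | mean, exp(logvar))."""
    log2pi = tf.log(2. * np.pi)
    return -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi)
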
Example #4
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format."""
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # the .npy file should contain N images of the same shape, as an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    y = tf.placeholder('float32', y_init.shape)
    y_tilde = y + tf.random.uniform(tf.shape(y), -0.5, 0.5)

    z = tf.placeholder('float32', z_init.shape)
    # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior
    # p(z_tilde) ("z_likelihoods")
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=True)
    z_hat = entropy_bottleneck._quantize(
        z, 'dequantize')  # rounded (with median centering)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    y_hat = conditional_bottleneck._quantize(
        y, 'dequantize')  # rounded (with mean centering)

    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    # crop reconstruction to have the same shape as input
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]

    # Total number of bits divided by number of pixels.
    # -log p(\tilde y | \tilde z) - log p(\tilde z)  (no bits-back correction in this variant)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # re-use the lmbda that was used for training
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')[0])
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        log_itv = 100
        if save_opt_record:
            log_itv = 10
        rd_lr = 0.005
        rd_opt_its = 2000
        from adam import Adam

        batch_idx = 0
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init],
                                        feed_dict=x_feed_dict)  # np arrays
                adam_optimizer = Adam(lr=rd_lr)
                opt_record = {
                    'its': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            y_hat_, z_hat_ = sess.run([y_hat, z_hat],
                                                      feed_dict={
                                                          y: y_cur,
                                                          z: z_cur
                                                      })
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: y_hat_,
                                    z_tilde: z_hat_,
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, obj, mse_, train_bpp_, psnr_,
                                   rd_loss_after_rounding, bpp_after_rounding,
                                   psnr_after_rounding))
                            opt_record['rd_loss_after_rounding'].append(
                                rd_loss_after_rounding)

                        else:
                            print(
                                'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, obj, mse_, train_bpp_, psnr_))
                        opt_record['its'].append(it)
                        opt_record['rd_loss'].append(obj)

                print()

                # these are the latents we end up transmitting
                y_hat_, z_hat_ = sess.run([y_hat, z_hat],
                                          feed_dict={
                                              y: y_cur,
                                              z: z_cur
                                          })

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_hat_,
                                         z_tilde: z_hat_,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        # current script name, without extension
        script_name = os.path.splitext(os.path.basename(__file__))[0]

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
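
# Sketch of the plain-numpy Adam optimizer imported from the local `adam` module in compress()
# above. The actual module is not shown in this excerpt; this is an assumed implementation
# exposing the update(params, grads) interface used there, not the original code.
class Adam:
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.lr, self.beta1, self.beta2, self.epsilon = lr, beta1, beta2, epsilon
        self.m = None  # first-moment estimates, one per parameter array
        self.v = None  # second-moment estimates
        self.t = 0     # step counter

    def update(self, params, grads):
        """Gradient-descent step on a list of numpy arrays; returns the updated list."""
        if self.m is None:
            self.m = [np.zeros_like(p) for p in params]
            self.v = [np.zeros_like(p) for p in params]
        self.t += 1
        # bias-corrected step size
        lr_t = self.lr * np.sqrt(1 - self.beta2 ** self.t) / (1 - self.beta1 ** self.t)
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g ** 2
            new_params.append(p - lr_t * self.m[i] / (np.sqrt(self.v[i]) + self.epsilon))
        return new_params
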
Example #5
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format."""
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # the .npy file should contain N images of the same shape, as an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    y = tf.placeholder('float32', y_init.shape)
    T = tf.placeholder('float32', shape=[], name='temperature')
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    epsilon = 1e-5
    ry_logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim holds the logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    ry = tf.nn.softmax(ry_logits, axis=-1)
    y_tilde = tf.reduce_sum(y_bds * ry, axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    # crop reconstruction to have the same shape as input
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]

    # # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior
    # # p(z_tilde) ("z_likelihoods")
    # z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)
    z = tf.placeholder('float32', z_init.shape)
    z_floor = tf.floor(z)
    z_ceil = tf.ceil(z)
    z_bds = tf.stack([z_floor, z_ceil], axis=-1)
    rz_logits = tf.stack([
        -tf.math.atanh(tf.clip_by_value(z - z_floor, -1 + epsilon,
                                        1 - epsilon)) / T,
        -tf.math.atanh(tf.clip_by_value(z_ceil - z, -1 + epsilon, 1 - epsilon))
        / T
    ],
                         axis=-1)  # last dim holds the logits for DOWN or UP
    rz = tf.nn.softmax(rz_logits, axis=-1)
    z_tilde = tf.reduce_sum(z_bds * rz, axis=-1)  # inner product in last dim

    # # We would have to manually call entropy_bottleneck.build, because here we don't call entropy_bottleneck directly
    # # as we did with 'z_tilde, z_likelihoods = entropy_bottleneck(z, training=training)' during training.
    # # UPDATE: this doesn't quite work, as the resulting variables don't get the proper name scope (they end up named
    # # "matrix_0", "bias_0", etc., instead of "entropy_bottleneck/matrix_0", "entropy_bottleneck/bias_0" as they would
    # # when calling entropy_bottleneck on a tensor), which breaks model loading ("Key bias_0 not found in checkpoint ...
    # # tensorflow.python.framework.errors_impl.NotFoundError: Restoring from checkpoint failed. This is most likely due
    # # to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not
    # # altered the graph expected based on the checkpoint.").
    # entropy_bottleneck.build(z_tilde.shape)
    _ = entropy_bottleneck(
        z, training=False
    )  # dummy call to ensure entropy_bottleneck is properly built
    z_likelihoods = entropy_bottleneck._likelihood(z_tilde)  # p(\tilde z)
    if entropy_bottleneck.likelihood_bound > 0:
        likelihood_bound = entropy_bottleneck.likelihood_bound
        z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_bound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels.
    # -log p(\tilde y | \tilde z) - log p(\tilde z)  (no bits-back correction in this variant)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # re-use the lmbda that was used for training
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')[0])
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        log_itv = 100
        if save_opt_record:
            log_itv = 10
        rd_lr = 0.005
        rd_opt_its = 2000
        annealing_rate = 4e-3
        T_ub = 0.2

        def annealed_temperature(t, r, ub, lb=1e-8, backend=np):
            # Using the exp schedule from section 4.2 of Jang et al., ICLR 2017
            if backend is None:
                return min(max(np.exp(-r * t), lb), ub)
            else:
                return backend.minimum(
                    backend.maximum(backend.exp(-r * t), lb), ub)

        from adam import Adam

        batch_idx = 0
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init],
                                        feed_dict=x_feed_dict)  # np arrays
                adam_optimizer = Adam(lr=rd_lr)
                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict, T: temperature
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_tilde: np.round(z_cur),
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_, rd_loss_after_rounding,
                                   bpp_after_rounding, psnr_after_rounding))
                            opt_record['rd_loss_after_rounding'].append(
                                rd_loss_after_rounding)
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_))
                        opt_record['its'].append(it)
                        opt_record['T'].append(temperature)
                        opt_record['rd_loss'].append(obj)

                print()

                # these are the latents we end up transmitting
                y_tilde_cur = np.round(y_cur)
                z_tilde_cur = np.round(z_cur)

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_tilde: z_tilde_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        # current script name, without extension
        script_name = os.path.splitext(os.path.basename(__file__))[0]

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
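
# A small numpy demo (not part of the original code) of the temperature-annealed soft rounding
# used in compress() above: with logits -atanh(y - floor(y)) / T and -atanh(ceil(y) - y) / T,
# y_tilde is a convex combination of floor(y) and ceil(y) that converges to round(y) as T -> 0.
def soft_round_demo(y, T, epsilon=1e-5):
    y = np.asarray(y, dtype='float32')
    y_floor, y_ceil = np.floor(y), np.ceil(y)
    logits = np.stack([
        -np.arctanh(np.clip(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
        -np.arctanh(np.clip(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
    ], axis=-1)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))  # stable softmax
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs[..., 0] * y_floor + probs[..., 1] * y_ceil

# e.g. soft_round_demo([0.2, 0.8], T=1.0) is roughly [0.29, 0.71], while
# soft_round_demo([0.2, 0.8], T=0.01) is essentially [0., 1.], i.e. np.round([0.2, 0.8]).
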
Example #6
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format."""
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # the .npy file should contain N images of the same shape, as an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)
    y_tilde = tf.round(y)

    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    # crop reconstruction to have the same shape as input
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2,
                                          axis=-1)
    z_mean = tf.placeholder(
        'float32',
        z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)

    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels.
    # -log p(\tilde y | \tilde z) - log p(\tilde z) + log q(\tilde z | \tilde y)  (the last term is the bits-back refund)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    bpp_back = tf.reduce_sum(
        -log_q_z_tilde, axis=axes_except_batch) / (np.log(2) * img_num_pixels)
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    local_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: []
                            for key in eval_fields
                            }  # append across all batches

        batch_idx = 0
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init],
                    feed_dict=x_feed_dict)  # np arrays

                opt_obj_hist = []
                opt_grad_hist = []
                lr = 0.005
                local_its = 1000
                from adam import Adam
                np_adam_optimizer = Adam(lr=lr)
                for it in range(local_its):
                    grads, obj = sess.run([local_gradients, train_bpp],
                                          feed_dict={
                                              z_mean: z_mean_cur,
                                              z_logvar: z_logvar_cur,
                                              **x_feed_dict
                                          })
                    z_mean_cur, z_logvar_cur = np_adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % 100 == 0:
                        print('negative local ELBO', obj)
                    opt_obj_hist.append(obj)
                    opt_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        # current script name, without extension
        script_name = os.path.splitext(os.path.basename(__file__))[0]
        save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
        if script_name != trained_script_name:
            save_file = 'rd-%s+%s-input=%s.npz' % (script_name, args.runname,
                                                   input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
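
# Hypothetical sketch of configs.get_eval_batch_size, which the compress() functions in this
# file import but which is not shown in this excerpt. The thresholds below are assumptions;
# the point is simply to use smaller evaluation batches for larger images so that the graph
# fits in memory.
def get_eval_batch_size(img_num_pixels):
    if img_num_pixels >= 512 * 768:    # roughly Kodak-sized images or larger
        return 1
    elif img_num_pixels >= 256 * 256:
        return 4
    else:
        return 16
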
Example #7
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.
    During training we sample from box-shaped posteriors; during compression this is approximated by rounding.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # z_tilde ~ q(z_tilde | x) = q(z_tilde | h_a(y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y),
                                num_or_size_splits=2,
                                axis=-1)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean
    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        sigma = math_ops.upper_bound(sigma, variance_upperbound**0.5)
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # sample y_tilde from q(y_tilde|x) = U(y-0.5, y+0.5) = U(g_a(x)-0.5, g_a(x)+0.5), and then compute the pdf of
    # y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde) = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    # Note that at test/compression time, the resulting y_tilde doesn't simply
    # equal round(y); instead, the conditional_bottleneck does something
    # smarter and slightly more optimal: y_hat=floor(y + 0.5 - prior_mean), so
    # that the mean (mu) of the prior coincides with one of the quantization bins.
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=training)

    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        # crop reconstruction to have the same shape as input
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]

    return locals()
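
# Sketch of the annealed_temperature helper imported from utils in the compress() function
# below. Only the plain exponential schedule appears inline earlier in this file; the 'exp0'
# behavior and the role of t0 sketched here are assumptions: hold the temperature at its upper
# bound for the first t0 iterations, then decay it exponentially.
def annealed_temperature(t, r, ub, lb=1e-8, t0=700, scheme='exp', **kwargs):
    if scheme == 'exp':
        T = np.exp(-r * t)
    elif scheme == 'exp0':
        T = ub * np.exp(-r * (t - t0))
    else:
        raise NotImplementedError(scheme)
    return min(max(T, lb), ub)
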
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format."""
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # the .npy file should contain N images of the same shape, as an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial optimization (where we still have access to x)
    # Soft-to-hard rounding with the Gumbel-softmax trick; for each element of y, let R be a 2D auxiliary one-hot
    # random vector, such that R=[1, 0] means rounding DOWN and R=[0, 1] means rounding UP.
    # Let the logits of the two outcomes be -(y - y_floor) / T and -(y_ceil - y) / T (a Boltzmann distribution with
    # energies (y - floor(y)) and (ceil(y) - y)), so p(R=[1, 0]) is proportional to exp(-(y - y_floor) / T).
    # With y_tilde = p(R=[1, 0]) * floor(y) + p(R=[0, 1]) * ceil(y), y_tilde -> round(y) as T -> 0.
    # (The code below actually uses -atanh(.) / T as logits, a monotone transform of the same fractional distances
    # with the same T -> 0 limit.)
    import tensorflow_probability as tfp
    T = tf.placeholder('float32', shape=[], name='temperature')
    y_init = analysis_transform(x)
    y = tf.placeholder('float32', y_init.shape)
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    epsilon = 1e-5
    logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim holds the logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    rounding_dist = tfp.distributions.RelaxedOneHotCategorical(
        T,
        logits=logits)  # technically we can use a different temperature here
    sample_concrete = rounding_dist.sample()
    y_tilde = tf.reduce_sum(y_bds * sample_concrete,
                            axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    # crop reconstruction to have the same shape as input
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2,
                                          axis=-1)
    z_mean = tf.placeholder(
        'float32',
        z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)

    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels.
    # -log p(\tilde y | \tilde z) - log p(\tilde z) + log q(\tilde z | \tilde y)  (the last term is the bits-back refund)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    batch_log_q_z_tilde = tf.reduce_sum(log_q_z_tilde, axis=axes_except_batch)
    bpp_back = -batch_log_q_z_tilde / (np.log(2) * img_num_pixels)
    batch_log_cond_p_y_tilde = tf.reduce_sum(tf.log(y_likelihoods),
                                             axis=axes_except_batch)
    y_bpp = -batch_log_cond_p_y_tilde / (np.log(2) * img_num_pixels)
    batch_log_p_z_tilde = tf.reduce_sum(tf.log(z_likelihoods),
                                        axis=axes_except_batch)
    z_bpp = -batch_log_p_z_tilde / (np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        # re-use the lmbda that was used for training
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')[0])
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z_mean, z_logvar])
    r_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)
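    # e.g., msssim = 0.99 maps to msssim_db = -10 * log10(1 - 0.99) = 20 dB (higher is better)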

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: [] for key in eval_fields}  # append across all batches

        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        log_itv = 100
        rd_lr = 0.005
        # rd_opt_its = args.sga_its
        rd_opt_its = 2000
        annealing_scheme = 'exp0'
        annealing_rate = args.annealing_rate  # default annealing_rate = 1e-3
        t0 = args.t0  # default t0 = 700
        T_ub = 0.5  # max/initial temperature
        from utils import annealed_temperature
        r_lr = 0.003
        r_opt_its = 2000
        from adam import Adam

        batch_idx = 0
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur = sess.run(y_init, feed_dict=x_feed_dict)  # np arrays
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init], feed_dict={y_tilde: y_cur})
                rd_loss_hist = []
                adam_optimizer = Adam(lr=rd_lr)

                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub,
                                                       scheme=annealing_scheme,
                                                       t0=t0)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z_mean: z_mean_cur,
                            z_logvar: z_logvar_cur,
                            **x_feed_dict, T: temperature
                        })
                    y_cur, z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [y_cur, z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_mean: z_mean_cur,
                                    z_logvar: z_logvar_cur,
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_, rd_loss_after_rounding,
                                   bpp_after_rounding, psnr_after_rounding))
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_))
                    rd_loss_hist.append(obj)
                print()

                # 2. Fix y_tilde, perform rate optimization w.r.t. z_mean and z_logvar.
                y_tilde_cur = np.round(y_cur)  # these are the latents we end up transmitting
                # rate_feed_dict = {y_tilde: y_tilde_cur, **x_feed_dict}
                rate_feed_dict = {y_tilde: y_tilde_cur}
                np.random.seed(seed)
                tf.set_random_seed(seed)
                print('----Rate Optimization----')
                # Reinitialize based on the value of y_tilde
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init],
                    feed_dict=rate_feed_dict)  # np arrays

                r_loss_hist = []
                # rate_grad_hist = []

                adam_optimizer = Adam(lr=r_lr)
                for it in range(r_opt_its):
                    grads, obj = sess.run(
                        [r_gradients, train_bpp],
                        feed_dict={
                            z_mean: z_mean_cur,
                            z_logvar: z_logvar_cur,
                            **rate_feed_dict
                        })
                    z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == r_opt_its:
                        print('it=', it, '\trate=', obj)
                    r_loss_hist.append(obj)
                    # rate_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # fig, axes = plt.subplots(nrows=2, sharex=True)
                # axes[0].plot(rd_loss_hist)
                # axes[0].set_ylabel('RD loss')
                # axes[1].plot(r_loss_hist)
                # axes[1].set_ylabel('Rate loss')
                # axes[1].set_xlabel('SGD iterations')
                # plt.savefig('plots/local_q_opt_hist-%s-input=%s-b=%d.png' %
                #             (args.runname, os.path.basename(args.input_file), batch_idx))

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[0]  # current script name, without extension
        save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
        if script_name != trained_script_name:
            save_file = 'rd-%s-lmbda=%g+%s-input=%s.npz' % (
                script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
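
For reference, here is a minimal NumPy sketch (not part of the original script; all names are illustrative) of the bits-back rate estimate computed above: per image, est_bpp = (-log2 p(\tilde y | \tilde z) - log2 p(\tilde z) + log2 q(\tilde z | \tilde y)) / num_pixels.

import numpy as np

def bits_back_bpp(y_likelihoods, z_likelihoods, log_q_z_tilde, num_pixels):
    """Bits per pixel with the bits-back correction, per image in the batch.

    y_likelihoods, z_likelihoods: element-wise likelihoods under p(y~|z~) and p(z~), shape [N, ...].
    log_q_z_tilde: element-wise log-density of z~ under the posterior q(z~|y~), shape [N, ...].
    """
    sum_all_but_batch = lambda a: np.sum(a, axis=tuple(range(1, a.ndim)))
    y_bits = -sum_all_but_batch(np.log(y_likelihoods)) / np.log(2)
    z_bits = -sum_all_but_batch(np.log(z_likelihoods)) / np.log(2)
    bits_back = -sum_all_but_batch(log_q_z_tilde) / np.log(2)
    return (y_bits + z_bits - bits_back) / num_pixels  # shape (N,)
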
Example #9
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format. or a batch of images of the same shape in npy format."""
    from configs import get_eval_batch_size

    # save_reconstruction / save_opt_record control optional outputs below; they are assumed to come in via the
    # CLI args (defaulting to False if absent), since they are not defined elsewhere in this function.
    save_reconstruction = getattr(args, 'save_reconstruction', False)
    save_opt_record = getattr(args, 'save_opt_record', False)

    if args.input_file.endswith('.npy'):
        # the .npy file should contain N images of the same shape, stored as an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_shape = (None, *X.shape[1:])
    x_ph = x = tf.placeholder('float32',
                              x_shape)  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    # Soft-to-hard rounding with the Gumbel-softmax trick: for each element of z, let R be a 2D auxiliary one-hot
    # random vector, such that R=[1, 0] means rounding DOWN and R=[0, 1] means rounding UP.
    # Let the logits of the two outcomes be -(z - z_floor) / T and -(z_ceil - z) / T (i.e., a Boltzmann distribution
    # with energies (z - floor(z)) and (ceil(z) - z)), so p(R==[1, 0]) is proportional to exp(-(z - z_floor) / T).
    # Let z_tilde = R[0] * floor(z) + R[1] * ceil(z), with R a relaxed one-hot sample; then z_tilde -> round(z) as T -> 0.
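    # Illustrative numbers (not in the original code): for z = 2.3 and T = 0.1, the atanh-based logits used below are
    # roughly -atanh(0.3)/T ~ -3.1 for DOWN and -atanh(0.7)/T ~ -8.7 for UP, so the relaxed sample puts nearly all of
    # its mass on DOWN and z_tilde lands close to floor(2.3) = 2 = round(2.3).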
    import tensorflow_probability as tfp
    T = tf.placeholder('float32', shape=[], name='temperature')

    z = tf.placeholder(
        'float32', z_init.shape
    )  # interface ("proxy") variable for SGA (to be annealed to int)
    z_floor = tf.floor(z)
    z_ceil = tf.ceil(z)
    z_bds = tf.stack([z_floor, z_ceil], axis=-1)
    rz_logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(z - z_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(z_ceil - z, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim are logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    rz_dist = tfp.distributions.RelaxedOneHotCategorical(
        T, logits=rz_logits
    )  # technically we can use a different temperature here
    rz_sample = rz_dist.sample()
    z_tilde = tf.reduce_sum(z_bds * rz_sample,
                            axis=-1)  # inner product in last dim

    _ = entropy_bottleneck(
        z, training=False
    )  # dummy call to ensure entropy_bottleneck is properly built
    z_likelihoods = entropy_bottleneck._likelihood(z_tilde)  # p(\tilde z)
    if entropy_bottleneck.likelihood_bound > 0:
        likelihood_bound = entropy_bottleneck.likelihood_bound
        z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_bound)

    # compute parameters of conditional prior p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # set up SGA for low-level latents
    y = tf.placeholder(
        'float32', y_init.shape
    )  # interface ("proxy") variable for SGA (to be annealed to int)
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    ry_logits = tf.stack([
        -tf.math.atanh(tf.clip_by_value(y - y_floor, -1 + epsilon,
                                        1 - epsilon)) / T,
        -tf.math.atanh(tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon))
        / T
    ],
                         axis=-1)  # last dim are logits for DOWN or UP
    ry_dist = tfp.distributions.RelaxedOneHotCategorical(
        T, logits=ry_logits
    )  # technically we can use a different temperature here
    ry_sample = ry_dist.sample()
    y_tilde = tf.reduce_sum(y_bds * ry_sample,
                            axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to the same spatial shape as the input

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # graph = build_graph(args, x, training=False)

    # Total number of bits divided by number of pixels.
    # - log p(\tilde y | \tilde z) - log p(\tilde z)   (no bits-back term here; \tilde z is quantized and transmitted directly)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')[0])  # reuse the lmbda used during training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255
    if save_reconstruction:
        x_tilde_float = x_tilde
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        all_results_arrs = {key: [] for key in eval_fields}  # append across all batches

        log_itv = 100
        if save_opt_record:
            log_itv = 10
        rd_lr = 0.005
        # rd_opt_its = args.sga_its
        rd_opt_its = 2000
        annealing_scheme = 'exp0'
        annealing_rate = args.annealing_rate  # default annealing_rate = 1e-3
        t0 = args.t0  # default t0 = 700
        T_ub = 0.5  # max/initial temperature
        from utils import annealed_temperature
        from adam import Adam

        batch_idx = 0
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur, z_cur = sess.run([y_init, z_init],
                                        feed_dict=x_feed_dict)  # np arrays
                adam_optimizer = Adam(lr=rd_lr)
                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub,
                                                       scheme=annealing_scheme,
                                                       t0=t0)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z: z_cur,
                            **x_feed_dict, T: temperature
                        })
                    y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_tilde: np.round(z_cur),
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_, rd_loss_after_rounding,
                                   bpp_after_rounding, psnr_after_rounding))
                            opt_record['rd_loss_after_rounding'].append(
                                rd_loss_after_rounding)
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_))
                        opt_record['its'].append(it)
                        opt_record['T'].append(temperature)
                        opt_record['rd_loss'].append(obj)

                print()

                y_tilde_cur = np.round(y_cur)  # these are the latents we end up transmitting
                z_tilde_cur = np.round(z_cur)

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_tilde: z_tilde_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[0]  # current script name, without extension

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record:
            # save optimization record
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        if save_reconstruction:
            assert num_images == 1
            prefix = 'recon'
            save_file = '%s-%s-input=%s.png' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g-rd_opt_its=%d+%s-input=%s.png' % (
                    prefix, script_name, args.lmbda, rd_opt_its, args.runname,
                    input_file)
            # Write reconstructed image out as a PNG file.
            save_file = os.path.join(args.results_dir, save_file)
            print("Saving image reconstruction to ", save_file)
            save_png_op = write_png(save_file, x_tilde_float[0])
            sess.run(save_png_op, feed_dict={y_tilde: y_tilde_cur})

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
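
The two local helpers imported in the optimization loops above, utils.annealed_temperature and adam.Adam, are not shown in this listing. The sketches below are plausible stand-ins inferred only from how they are called (annealed_temperature(it, r=..., ub=..., scheme='exp0', t0=...) and Adam(lr=...).update(list_of_arrays, grads)); they are assumptions, not the project's actual implementations.

import numpy as np

def annealed_temperature(t, r, ub, scheme='exp0', t0=0, lb=1e-8):
    # Assumed schedule: hold the temperature at its upper bound ub for the first t0 steps,
    # then decay it exponentially at rate r; always clip to [lb, ub].
    if scheme == 'exp0':
        tau = ub * np.exp(-r * (t - t0))
    else:  # plain exponential decay from step 0
        tau = ub * np.exp(-r * t)
    return float(np.clip(tau, lb, ub))

class Adam:
    # Stand-in matching the call pattern Adam(lr=...).update([arr1, arr2, ...], grads).
    def __init__(self, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = self.v = None
        self.t = 0

    def update(self, params, grads):
        if self.m is None:
            self.m = [np.zeros_like(p) for p in params]
            self.v = [np.zeros_like(p) for p in params]
        self.t += 1
        out = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g ** 2
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            out.append(p - self.lr * m_hat / (np.sqrt(v_hat) + self.eps))
        return out

Under the assumed schedule, the defaults used above (T_ub = 0.5, t0 = 700, annealing_rate = 1e-3) would keep the temperature at 0.5 for the first 700 SGA iterations and then let it decay smoothly (to about 0.14 by iteration 2000).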