Example 1
def vae_cross_ent_loss(model, x, beta=1.):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)
    logpx_z = cross_ent_loss(x_logit, x)
    logpz = utils.log_normal_pdf(z, 0., 0.)
    logqz_x = utils.log_normal_pdf(z, mean, logvar)
    kld = logqz_x - logpz
    return -tf.reduce_mean(logpx_z - beta*kld)
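
The loss above is the negative beta-ELBO: a Bernoulli cross-entropy reconstruction term plus a single-sample Monte Carlo KL estimate, log q(z|x) - log p(z). For reference, below is a minimal sketch of the diagonal-Gaussian log-density helper that utils.log_normal_pdf is assumed to provide (the project's actual implementation may differ):

import numpy as np
import tensorflow as tf

def log_normal_pdf(sample, mean, logvar, raxis=1):
    # log N(sample | mean, exp(logvar)) of a diagonal Gaussian, summed over the latent dimensions.
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(
        -.5 * ((sample - mean)**2. * tf.exp(-logvar) + logvar + log2pi),
        axis=raxis)
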
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.
    During training we sample from box-shaped posteriors; during compression this is approximated by rounding.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # y_tilde ~ q(y_tilde | y = g_a(x))
    half = tf.constant(.5, dtype=y.dtype)
    if training:
        noise = tf.random.uniform(tf.shape(y), -half, half)
        y_tilde = y + noise
    else:
        # Approximately sample from q(y_tilde|x) by rounding. We can't be smart and do y_hat=floor(y + 0.5 - prior_mean) as
        # in Balle's model (ultimately implemented by conditional_bottleneck._quantize), because we don't have the prior
        # p(y_tilde | z_tilde) yet; in bb we have to sample z_tilde given y_tilde, whereas in BMSHJ2018, z_tilde is obtained
        # conditioned on x.
        y_tilde = tf.round(y)

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y_tilde),
                                num_or_size_splits=2,
                                axis=-1)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        sigma = math_ops.upper_bound(sigma, variance_upperbound**0.5)
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)

    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to have the same shape as input

    return locals()
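
build_graph returns locals() so callers can assemble objectives from the intermediate tensors. As a minimal sketch (using the names returned above and the same reductions as the compress() examples further down; assumes 4-D [batch, H, W, C] tensors and the tf.compat.v1 API), the bits-back rate in bits per pixel could be formed like this:

import numpy as np
import tensorflow.compat.v1 as tf

def bits_back_bpp(graph, num_pixels):
    # Rate per pixel: -log p(y_tilde|z_tilde) - log p(z_tilde) + log q(z_tilde|y_tilde), averaged over the batch.
    axes = [1, 2, 3]
    y_bits = tf.reduce_sum(-tf.log(graph['y_likelihoods']), axis=axes) / np.log(2)
    z_bits = tf.reduce_sum(-tf.log(graph['z_likelihoods']), axis=axes) / np.log(2)
    back_bits = tf.reduce_sum(-graph['log_q_z_tilde'], axis=axes) / np.log(2)
    return tf.reduce_mean((y_bits + z_bits - back_bits) / num_pixels)
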
Example 3
def main():
    wandb.init(project="vae-comparison")
    wandb.config.update(args)
    log_step = 0

    # set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # set device
    use_gpu = args.use_gpu and torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")
    print("training on {} device".format("cuda" if use_gpu else "cpu"))

    # load dataset
    train_loader, val_loader, test_loader = load_data(
        dataset=args.dataset,
        batch_size=args.batch_size,
        no_validation=args.no_validation,
        shuffle=args.shuffle,
        data_file=args.data_file)

    # define model or load checkpoint
    if args.train_from == '':
        print('--------------------------------')
        print("initializing new model")
        model = VAE(latent_dim=args.latent_dim)

    else:
        print('--------------------------------')
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']

    print('--------------------------------')
    print("model architecture")
    print(model)

    # set model for training
    model.to(device)
    model.train()

    # define optimizers and their schedulers
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_enc = torch.optim.Adam(model.enc.parameters(), lr=args.lr)
    optimizer_dec = torch.optim.Adam(model.dec.parameters(), lr=args.lr)
    lr_lambda = lambda count: 0.9
    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer, lr_lambda=lr_lambda)
    lr_scheduler_enc = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer_enc, lr_lambda=lr_lambda)
    lr_scheduler_dec = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer_dec, lr_lambda=lr_lambda)

    # set beta KL scaling parameter
    if args.warmup == 0:
        beta_ten = torch.tensor(1.)
    else:
        beta_ten = torch.tensor(0.1)

    # set savae meta optimizer
    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(utils.variational_loss,
                              model,
                              update_params,
                              beta=beta_ten,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=1,
                              max_grad_norm=args.svi_max_grad_norm)

    # if test flag set, evaluate and exit
    if args.test == 1:
        beta_ten.data.fill_(1.)
        eval(test_loader, model, meta_optimizer, device)
        importance_sampling(data=test_loader,
                            model=model,
                            batch_size=args.batch_size,
                            meta_optimizer=meta_optimizer,
                            device=device,
                            nr_samples=20000,
                            test_mode=True,
                            verbose=True,
                            mode=args.test_type)
        exit()

    # initialize counters and stats
    epoch = 0
    t = 0
    best_val_metric = 100000000
    best_epoch = 0
    loss_stats = []
    # training loop
    C = torch.tensor(0., device=device)
    C_local = torch.zeros(args.batch_size * len(train_loader), device=device)
    epsilon = None
    step = 0
    while epoch < args.num_epochs:

        start_time = time.time()
        epoch += 1

        print('--------------------------------')
        print('starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_cdiv = 0.
        train_nll = 0.
        train_acc_rate = 0.
        num_examples = 0
        count_one_pixels = 0

        for b, datum in enumerate(train_loader):
            t += 1

            if args.warmup > 0:
                beta_ten.data.fill_(
                    torch.min(torch.tensor(1.), beta_ten + 1 /
                              (args.warmup * len(train_loader))).data)

            img, _ = datum
            img = torch.where(img < 0.5, torch.zeros_like(img),
                              torch.ones_like(img))
            if epoch == 1:
                count_one_pixels += torch.sum(img).item()
            img = img.to(device)

            optimizer.zero_grad()
            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()

            if args.model == 'svi':
                mean_svi = torch.zeros(args.batch_size,
                                       args.latent_dim,
                                       requires_grad=True,
                                       device=device)
                logvar_svi = torch.zeros(args.batch_size,
                                         args.latent_dim,
                                         requires_grad=True,
                                         device=device)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                        img)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model.reparameterize(mean_svi_final.detach(),
                                                 logvar_svi_final.detach())
                preds = model.dec_forward(z_samples)
                nll_svi = utils.log_bernoulli_loss(preds, img)
                train_nll_svi += nll_svi.item() * args.batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * args.batch_size
                var_loss = nll_svi + beta_ten.item() * kl_svi
                var_loss.backward()

                if args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()

            else:
                mean, logvar = model.enc_forward(img)
                z_samples = model.reparameterize(mean, logvar)
                preds = model.dec_forward(z_samples)
                nll_vae = utils.log_bernoulli_loss(preds, img)
                train_nll_vae += nll_vae.item() * args.batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.item() * args.batch_size

                if args.model == 'vae':
                    vae_loss = nll_vae + beta_ten.item() * kl_vae
                    vae_loss.backward()

                    optimizer.step()

                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = mean.clone().detach().requires_grad_(True)
                    logvar_svi = logvar.clone().detach().requires_grad_(True)

                    var_params_svi = meta_optimizer.forward(
                        [mean_svi, logvar_svi], img)
                    mean_svi_final, logvar_svi_final = var_params_svi

                    z_samples = model.reparameterize(mean_svi_final,
                                                     logvar_svi_final)
                    preds = model.dec_forward(z_samples)
                    nll_svi = utils.log_bernoulli_loss(preds, img)
                    train_nll_svi += nll_svi.item() * args.batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final,
                                                logvar_svi_final)
                    train_kl_svi += kl_svi.item() * args.batch_size
                    var_loss = nll_svi + beta_ten.item() * kl_svi
                    var_loss.backward(retain_graph=True)
                    var_param_grads = meta_optimizer.backward(
                        [mean_svi_final.grad, logvar_svi_final.grad])
                    var_param_grads = torch.cat(var_param_grads, 1)
                    var_params.backward(var_param_grads)

                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()

                if args.model == "cdiv" or args.model == "cdiv_svgd":

                    pxz = utils.log_pxz(preds, img, z_samples)
                    first_term = torch.mean(pxz) + 0.5 * args.latent_dim
                    logqz = utils.log_normal_pdf(z_samples, mean, logvar)

                    if epoch == 7 and b == 0:  # switch to local variate control
                        C_local = torch.ones(
                            args.batch_size * len(train_loader),
                            device=device) * C

                    if args.model == "cdiv":
                        zt, samples, acc_rate, epsilon = hmc.hmc_vae(
                            z_samples.clone().detach().requires_grad_(),
                            model,
                            img,
                            epsilon=epsilon,
                            Burn=0,
                            T=args.num_hmc_iters,
                            adapt=0,
                            L=5)
                        train_acc_rate += torch.mean(
                            acc_rate) * args.batch_size
                    else:
                        mean_all = torch.repeat_interleave(
                            mean, args.num_svgd_particles, 0)
                        logvar_all = torch.repeat_interleave(
                            logvar, args.num_svgd_particles, 0)
                        img_all = torch.repeat_interleave(
                            img, args.num_svgd_particles, 0)
                        z_samples = mean_all + torch.randn(
                            args.num_svgd_particles * args.batch_size,
                            args.latent_dim,
                            device=device) * torch.exp(0.5 * logvar_all)
                        samples = svgd.svgd_batched(args.num_svgd_particles,
                                                    args.batch_size,
                                                    z_samples,
                                                    model,
                                                    img_all.view(-1, 784),
                                                    iter=args.num_svgd_iters)
                        z_ind = torch.randint(low=0, high=args.num_svgd_particles, size=(args.batch_size,),
                                              device=device) + \
                                torch.tensor(args.num_svgd_particles, device=device) * \
                                torch.arange(0, args.batch_size, device=device)
                        zt = samples[z_ind]

                    preds_zt = model.dec_forward(zt)

                    pxzt = utils.log_pxz(preds_zt, img, zt)
                    g_zt = pxzt + torch.sum(
                        0.5 * ((zt - mean)**2) * torch.exp(-logvar), 1)

                    second_term = torch.mean(g_zt)
                    cdiv = -first_term + second_term
                    train_cdiv += cdiv.item() * args.batch_size
                    train_nll += -torch.mean(pxzt).item() * args.batch_size

                    if epoch <= 6:
                        loss = -first_term + torch.mean(
                            torch.sum(
                                0.5 *
                                ((zt - mean)**2) * torch.exp(-logvar), 1) +
                            (g_zt.detach() - C) * logqz)
                        if b == 0:
                            C = torch.mean(g_zt.detach())
                        else:
                            C = 0.9 * C + 0.1 * torch.mean(g_zt.detach())
                    else:
                        control = C_local[b * args.batch_size:(b + 1) *
                                          args.batch_size]
                        loss = -first_term + torch.mean(
                            torch.sum(
                                0.5 *
                                ((zt - mean)**2) * torch.exp(-logvar), 1) +
                            (g_zt.detach() - control) * logqz)
                        C_local[b * args.batch_size:(b + 1) * args.batch_size] = \
                            0.9 * C_local[b * args.batch_size:(b + 1) * args.batch_size] + 0.1 * g_zt.detach()

                    loss.backward(retain_graph=True)
                    optimizer_enc.step()

                    optimizer_dec.zero_grad()
                    torch.mean(-utils.log_pxz(preds_zt, img, zt)).backward()
                    optimizer_dec.step()

            if t % 15000 == 0:
                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    lr_scheduler_enc.step()
                    lr_scheduler_dec.step()
                else:
                    lr_scheduler.step()

            num_examples += args.batch_size
            if b and (b + 1) % args.print_every == 0:
                step += 1

                print('--------------------------------')
                print('iteration: %d, epoch: %d, batch: %d/%d' %
                      (t, epoch, b + 1, len(train_loader)))
                if epoch > 1:
                    print('best epoch: %d: %.2f' %
                          (best_epoch, best_val_metric))
                print('throughput: %.2f examples/sec' %
                      (num_examples / (time.time() - start_time)))

                if args.model != 'svi':
                    print(
                        'train_VAE_NLL: %.2f, train_VAE_KL: %.4f, train_VAE_NLLBnd: %.2f'
                        % (train_nll_vae / num_examples,
                           train_kl_vae / num_examples,
                           (train_nll_vae + train_kl_vae) / num_examples))
                    wandb.log(
                        {
                            "train_vae_nll":
                            train_nll_vae / num_examples,
                            "train_vae_kl":
                            train_kl_vae / num_examples,
                            "train_vae_nll_bound":
                            (train_nll_vae + train_kl_vae) / num_examples,
                        },
                        step=log_step)

                if args.model == 'svi' or args.model == 'savae':
                    print(
                        'train_SVI_NLL: %.2f, train_SVI_KL: %.4f, train_SVI_NLLBnd: %.2f'
                        % (train_nll_svi / num_examples,
                           train_kl_svi / num_examples,
                           (train_nll_svi + train_kl_svi) / num_examples))
                    wandb.log(
                        {
                            "train_svi_nll":
                            train_nll_svi / num_examples,
                            "train_svi_kl":
                            train_kl_svi / num_examples,
                            "train_svi_nll_bound":
                            (train_nll_svi + train_kl_svi) / num_examples,
                        },
                        step=log_step)

                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    print(
                        'train_NLL: %.2f, train_CDIV: %.4f' %
                        (train_nll / num_examples, train_cdiv / num_examples))
                    wandb.log(
                        {
                            "train_nll": train_nll / num_examples,
                            "train_cdiv": train_cdiv / num_examples,
                        },
                        step=log_step)

                    if args.model == "cdiv":
                        print('train_average_acc_rate: %.3f' %
                              (train_acc_rate / num_examples))
                        wandb.log(
                            {
                                "train_average_acc_rate":
                                train_acc_rate / num_examples,
                            },
                            step=log_step)
                log_step += 1

        if epoch == 1:
            print('--------------------------------')
            print("count of pixels 1 in training data: {}".format(
                count_one_pixels))
            wandb.log({"dataset_pixel_check": count_one_pixels}, step=log_step)
        if args.no_validation:
            print('--------------------------------')
            print("[validation disabled!]")
        else:
            val_metric = eval(val_loader, model, meta_optimizer, device, epoch,
                              epsilon, log_step)

        checkpoint = {
            'args': args.__dict__,
            'model': model,
            'loss_stats': loss_stats
        }
        torch.save(checkpoint, args.checkpoint_path + "_last.pt")
        if not args.no_validation:
            loss_stats.append(val_metric)
            if val_metric < best_val_metric:
                best_val_metric = val_metric
                best_epoch = epoch
                print('saving checkpoint to %s' %
                      (args.checkpoint_path + "_best.pt"))
                torch.save(checkpoint, args.checkpoint_path + "_best.pt")
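
For readability, the KL warm-up applied at the top of every batch in the loop above can be restated as a small helper (a hypothetical name; it performs the same update as the in-loop code):

import torch

def beta_warmup_step(beta_ten, warmup_epochs, batches_per_epoch):
    # Linear KL warm-up: raise beta by 1 / (warmup_epochs * batches_per_epoch)
    # per batch, capping it at 1.0.
    if warmup_epochs > 0:
        new_beta = torch.min(torch.tensor(1.),
                             beta_ten + 1. / (warmup_epochs * batches_per_epoch))
        beta_ten.data.fill_(new_beta.data)
    return beta_ten
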
def importance_sampling(data,
                        model,
                        batch_size,
                        meta_optimizer,
                        device,
                        nr_samples,
                        test_mode,
                        verbose=False,
                        mode="vae",
                        adapt=False,
                        epsilon=None,
                        log_step=0):
    """
    Computes importance sampling estimate of data probability.

    :param data: datapoints
    :param model: VAE model
    :param batch_size: batch size
    :param meta_optimizer: savae meta optimizer
    :param device: torch device
    :param nr_samples: number of samples for importance sampling
    :param test_mode: true if run on test data
    :param verbose: true if current estimate should be outputted
    :param mode: mode of evaluation - way of generating sampling distribution
    :param adapt: adapt parameter for hmc
    :param epsilon: epsilon parameter for hmc
    :param log_step: step for metrics logging
    :return: importance sampling estimate of data probability
    """

    model.eval()
    results = torch.zeros(batch_size * len(data))
    disp_const = 2 * math.log(1.2)
    t = 0
    S = nr_samples

    for datum in data:
        img_batch, _ = datum
        img_batch = img_batch.to(device)
        for img_single in img_batch:

            img_single = torch.where(img_single < 0.5,
                                     torch.zeros_like(img_single),
                                     torch.ones_like(img_single))
            img = img_single.repeat(S, 1, 1)

            if mode == 'svi':
                mean_svi = 0.1 * torch.zeros(batch_size,
                                             model.latent_dim,
                                             device=device,
                                             requires_grad=True)
                logvar_svi = 0.1 * torch.zeros(batch_size,
                                               model.latent_dim,
                                               device=device,
                                               requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi],
                                                        img)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model.reparameterize(mean_svi_final.detach(),
                                                 logvar_svi_final.detach())
                preds = model.dec_forward(z_samples)
                mean, logvar = mean_svi_final, logvar_svi_final

            elif mode == 'cdiv' or mode == "cdiv_svgd":
                mean, logvar = model.enc_forward(img_single.unsqueeze(0))
                z_samples = model.reparameterize(mean, disp_const + logvar)
                if mode == "cdiv":
                    z, samples, acc_rate, _ = hmc.hmc_vae(
                        z_samples,
                        model,
                        img_single.unsqueeze(0),
                        epsilon=epsilon,
                        Burn=0,
                        T=300,
                        adapt=adapt,
                        L=5)
                    mean = torch.mean(samples, 2)
                    var = torch.var(samples, 2)
                else:
                    ind = 0
                    z_samples = mean[ind].unsqueeze(0) + \
                                torch.randn(300, mean.size(1), device=device) * \
                                torch.exp(0.5 * logvar[ind]).unsqueeze(0)
                    samples = svgd.svgd(z_samples,
                                        model,
                                        img_single.view(-1, 784),
                                        iter=20)
                    mean = torch.mean(samples, 0)
                    var = torch.var(samples, 0)

                eps = 0.00000001
                var = torch.where(var < eps, torch.ones_like(var) * eps, var)
                logvar = torch.log(var)
                z_samples = model.reparameterize(
                    mean.repeat(S, 1), disp_const + logvar.repeat(S, 1))
                preds = model.dec_forward(z_samples)
            else:
                mean, logvar = model.enc_forward(img)
                z_samples = model.reparameterize(mean, disp_const + logvar)
                preds = model.dec_forward(z_samples)
                if mode == 'savae':
                    mean_svi = mean.data.clone().detach().requires_grad_(True)
                    logvar_svi = logvar.clone().detach().requires_grad_(True)
                    var_params_svi = meta_optimizer.forward(
                        [mean_svi, logvar_svi], img)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model.reparameterize(
                        mean_svi_final, disp_const + logvar_svi_final)
                    preds = model.dec_forward(z_samples.detach())
                    mean, logvar = mean_svi_final, logvar_svi_final

            log_pxz = utils.log_pxz(preds, img, z_samples)
            logqz = utils.log_normal_pdf(z_samples, mean, disp_const + logvar)
            results[t] = (torch.logsumexp(log_pxz - logqz, 0) -
                          math.log(S)).detach()
            current_mean = torch.sum(results) / (t + 1)
            if verbose and t % 100 == 0 and t:
                print("---> IS estimate after {} examples".format(t))
                print(current_mean.item())

            if test_mode and t > 100:
                wandb.log({"IS_mean": current_mean})
            t += 1

    if test_mode:
        wandb.log({"test_importance_sampling_nll": current_mean})
        print('--------------------------------')
        print("IS {} samples per datapoint".format(S))
        print("test IS estimate: {}".format(current_mean.item()))
    else:
        wandb.log({"val_importance_sampling_nll": current_mean}, step=log_step)
        print('--------------------------------')
        print("IS {} samples per datapoint".format(S))
        print("val IS estimate: {}".format(current_mean.item()))

    model.train()
    return current_mean.item()
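
The per-image quantity written into results[t] is the standard importance-sampling estimate of the marginal likelihood. Isolated as a hypothetical helper (mirroring the computation above), it is:

import math
import torch

def is_log_px(log_pxz, log_qz):
    # log p(x) is estimated as logsumexp_s(log p(x, z_s) - log q(z_s | x)) - log S,
    # with the S samples indexed along dim 0.
    S = log_pxz.shape[0]
    return torch.logsumexp(log_pxz - log_qz, dim=0) - math.log(S)
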
Example 5
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format."""
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)
    y_tilde = tf.round(y)

    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to have the same shape as input

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2,
                                          axis=-1)
    z_mean = tf.placeholder(
        'float32',
        z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)

    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels.
    # -log p(\tilde y | \tilde z) - log p(\tilde z) + log q(\tilde z | \tilde y)  (the last term is the bits-back correction)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    bpp_back = tf.reduce_sum(
        -log_q_z_tilde, axis=axes_except_batch) / (np.log(2) * img_num_pixels)
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    local_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: [] for key in eval_fields}  # append across all batches

        batch_idx = 0
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init],
                    feed_dict=x_feed_dict)  # np arrays

                opt_obj_hist = []
                opt_grad_hist = []
                lr = 0.005
                local_its = 1000
                from adam import Adam
                np_adam_optimizer = Adam(lr=lr)
                for it in range(local_its):
                    grads, obj = sess.run([local_gradients, train_bpp],
                                          feed_dict={
                                              z_mean: z_mean_cur,
                                              z_logvar: z_logvar_cur,
                                              **x_feed_dict
                                          })
                    z_mean_cur, z_logvar_cur = np_adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % 100 == 0:
                        print('negative local ELBO', obj)
                    opt_obj_hist.append(obj)
                    opt_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[0]  # current script name, without extension
        save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
        if script_name != trained_script_name:
            save_file = 'rd-%s+%s-input=%s.npz' % (script_name, args.runname,
                                                   input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
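
The local optimization above relies on a small NumPy Adam implementation (from adam import Adam) with an update(params, grads) interface. That module is not shown here; a minimal sketch consistent with how it is called (gradient descent on the objective) might look like the following, though the actual implementation may differ:

import numpy as np

class Adam:
    # Minimal NumPy Adam with the update([params], [grads]) interface used above.
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m, self.v, self.t = None, None, 0

    def update(self, params, grads):
        # First and second moment buffers are created lazily to match the parameter shapes.
        if self.m is None:
            self.m = [np.zeros_like(p) for p in params]
            self.v = [np.zeros_like(p) for p in params]
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            m_hat = self.m[i] / (1 - self.beta1**self.t)
            v_hat = self.v[i] / (1 - self.beta2**self.t)
            new_params.append(p - self.lr * m_hat / (np.sqrt(v_hat) + self.eps))
        return new_params
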
Example 6
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other quantities of interest.
    During training we sample from box-shaped posteriors; during compression this is approximated by rounding.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # z_tilde ~ q(z_tilde | x) = q(z_tilde | h_a(y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y),
                                num_or_size_splits=2,
                                axis=-1)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean
    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        sigma = math_ops.upper_bound(sigma, variance_upperbound**0.5)
    if not training:  # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # sample y_tilde from q(y_tilde|x) = U(y-0.5, y+0.5) = U(g_a(x)-0.5, g_a(x)+0.5), and then compute the pdf of
    # y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde) = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    # Note that at test/compression time, the resulting y_tilde doesn't simply
    # equal round(y); instead, the conditional_bottleneck does something
    # smarter and slightly more optimal: y_hat=floor(y + 0.5 - prior_mean), so
    # that the mean (mu) of the prior coincides with one of the quantization bins.
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=training)

    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to have the same shape as input

    return locals()
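
As the comment above notes, at test time the conditional bottleneck quantizes with an offset so that the prior mean mu lands on a quantization bin center. A minimal sketch of that quantizer (roughly what the comment describes; adding the mean back afterwards is an assumption here, not stated in the comment):

import tensorflow.compat.v1 as tf

def quantize_with_offset(y, prior_mean):
    # floor(y + 0.5 - prior_mean) rounds (y - prior_mean) to the nearest integer;
    # adding the mean back places the quantized value on the mean-shifted grid.
    return tf.floor(y + 0.5 - prior_mean) + prior_mean
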
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format."""
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore if multiple ops need to be evaluated on the same batch of data, they have to be grouped like
    # sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 *
                                                      args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial optimization (where we still have access to x)
    # Soft-to-hard rounding with the Gumbel-softmax trick; for each element of y, let R be a 2D auxiliary one-hot
    # random vector, such that R=[1, 0] means rounding DOWN and R=[0, 1] means rounding UP.
    # Let the logits of the two outcomes be -(y - y_floor) / T and -(y_ceil - y) / T (a Boltzmann distribution with
    # energies (y - floor(y)) and (ceil(y) - y)), so p(R==[1, 0]) is proportional to exp(-(y - y_floor) / T), etc.
    # Let y_tilde = p(R==[1, 0]) * floor(y) + p(R==[0, 1]) * ceil(y), so y_tilde -> round(y) as T -> 0.
    import tensorflow_probability as tfp
    T = tf.placeholder('float32', shape=[], name='temperature')
    y_init = analysis_transform(x)
    y = tf.placeholder('float32', y_init.shape)
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    epsilon = 1e-5
    logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim are logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    rounding_dist = tfp.distributions.RelaxedOneHotCategorical(
        T,
        logits=logits)  # technically we can use a different temperature here
    sample_concrete = rounding_dist.sample()
    y_tilde = tf.reduce_sum(y_bds * sample_concrete,
                            axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to have the same shape as input

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2,
                                          axis=-1)
    z_mean = tf.placeholder(
        'float32',
        z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)

    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean

    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels.
    # -log p(\tilde y | \tilde z) - log p(\tilde z) + log q(\tilde z | \tilde y)  (the last term is the bits-back correction)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    batch_log_q_z_tilde = tf.reduce_sum(log_q_z_tilde, axis=axes_except_batch)
    bpp_back = -batch_log_q_z_tilde / (np.log(2) * img_num_pixels)
    batch_log_cond_p_y_tilde = tf.reduce_sum(tf.log(y_likelihoods),
                                             axis=axes_except_batch)
    y_bpp = -batch_log_cond_p_y_tilde / (np.log(2) * img_num_pixels)
    batch_log_p_z_tilde = tf.reduce_sum(tf.log(z_likelihoods),
                                        axis=axes_except_batch)
    z_bpp = -batch_log_p_z_tilde / (np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')[0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z_mean, z_logvar])
    r_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: [] for key in eval_fields}  # append across all batches

        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        log_itv = 100
        rd_lr = 0.005
        # rd_opt_its = args.sga_its
        rd_opt_its = 2000
        annealing_scheme = 'exp0'
        annealing_rate = args.annealing_rate  # default annealing_rate = 1e-3
        t0 = args.t0  # default t0 = 700
        T_ub = 0.5  # max/initial temperature
        from utils import annealed_temperature
        r_lr = 0.003
        r_opt_its = 2000
        from adam import Adam

        batch_idx = 0
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}
                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur = sess.run(y_init, feed_dict=x_feed_dict)  # np arrays
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init], feed_dict={y_tilde: y_cur})
                rd_loss_hist = []
                adam_optimizer = Adam(lr=rd_lr)

                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it,
                                                       r=annealing_rate,
                                                       ub=T_ub,
                                                       scheme=annealing_scheme,
                                                       t0=t0)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z_mean: z_mean_cur,
                            z_logvar: z_logvar_cur,
                            **x_feed_dict, T: temperature
                        })
                    y_cur, z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [y_cur, z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_mean: z_mean_cur,
                                    z_logvar: z_logvar_cur,
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_, rd_loss_after_rounding,
                                   bpp_after_rounding, psnr_after_rounding))
                        else:
                            print(
                                'it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                                % (it, temperature, obj, mse_, train_bpp_,
                                   psnr_))
                    rd_loss_hist.append(obj)
                print()

                # 2. Fix y_tilde, perform rate optimization w.r.t. z_mean and z_logvar.
                y_tilde_cur = np.round(y_cur)  # these are the latents we end up transmitting
                # rate_feed_dict = {y_tilde: y_tilde_cur, **x_feed_dict}
                rate_feed_dict = {y_tilde: y_tilde_cur}
                np.random.seed(seed)
                tf.set_random_seed(seed)
                print('----Rate Optimization----')
                # Reinitialize based on the value of y_tilde
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init],
                    feed_dict=rate_feed_dict)  # np arrays

                r_loss_hist = []
                # rate_grad_hist = []

                adam_optimizer = Adam(lr=r_lr)
                for it in range(r_opt_its):
                    grads, obj = sess.run(
                        [r_gradients, train_bpp],
                        feed_dict={
                            z_mean: z_mean_cur,
                            z_logvar: z_logvar_cur,
                            **rate_feed_dict
                        })
                    z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == r_opt_its:
                        print('it=', it, '\trate=', obj)
                    r_loss_hist.append(obj)
                    # rate_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # fig, axes = plt.subplots(nrows=2, sharex=True)
                # axes[0].plot(rd_loss_hist)
                # axes[0].set_ylabel('RD loss')
                # axes[1].plot(r_loss_hist)
                # axes[1].set_ylabel('Rate loss')
                # axes[1].set_xlabel('SGD iterations')
                # plt.savefig('plots/local_q_opt_hist-%s-input=%s-b=%d.png' %
                #             (args.runname, os.path.basename(args.input_file), batch_idx))

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1

            except tf.errors.OutOfRangeError:
                break

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[0]  # current script name, without extension
        save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
        if script_name != trained_script_name:
            save_file = 'rd-%s-lmbda=%g+%s-input=%s.npz' % (
                script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
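
The temperature schedule used during the R-D optimization comes from utils.annealed_temperature, called with scheme='exp0', rate annealing_rate (default 1e-3), t0 (default 700) and upper bound T_ub = 0.5. A plausible sketch of such a schedule (hold the temperature near its upper bound for roughly t0 steps, then decay exponentially; the actual helper may differ):

import numpy as np

def annealed_temperature(t, r, ub, lb=1e-8, scheme='exp0', t0=700, **kwargs):
    # Exponentially decaying temperature with an initial plateau, clipped to [lb, ub].
    if scheme == 'exp0':
        tau = ub * np.exp(-r * (t - t0))
    else:  # plain exponential decay from 1.0
        tau = np.exp(-r * t)
    return float(np.clip(tau, lb, ub))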