def vae_cross_ent_loss(model, x, beta=1.):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)
    logpx_z = cross_ent_loss(x_logit, x)
    logpz = utils.log_normal_pdf(z, 0., 0.)
    logqz_x = utils.log_normal_pdf(z, mean, logvar)
    kld = logqz_x - logpz
    return -tf.reduce_mean(logpx_z - beta * kld)

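# `utils.log_normal_pdf` is not shown in this file. Below is a minimal sketch of what it is assumed
# to compute -- the elementwise log-density of a diagonal Gaussian, log N(sample | mean, exp(logvar)).
# The repo's actual helper may instead reduce over the latent axes; treat this as an illustration,
# not the repo's implementation.
def log_normal_pdf_sketch(sample, mean, logvar):
    log2pi = float(np.log(2. * np.pi))
    return -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi)
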
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other
    quantities of interest. During training we sample from box-shaped posteriors; during compression
    this is approximated by rounding.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 * args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 * args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # y_tilde ~ q(y_tilde | y = g_a(x))
    half = tf.constant(.5, dtype=y.dtype)
    if training:
        noise = tf.random.uniform(tf.shape(y), -half, half)
        y_tilde = y + noise
    else:
        # Approximately sample from q(y_tilde|x) by rounding. We can't be smart and do
        # y_hat = floor(y + 0.5 - prior_mean) as in Balle's model (ultimately implemented by
        # conditional_bottleneck._quantize), because we don't have the prior p(y_tilde | z_tilde) yet;
        # in bits-back we have to sample z_tilde given y_tilde, whereas in BMSHJ2018, z_tilde is
        # obtained conditioned on x.
        y_tilde = tf.round(y)

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y_tilde),
                                num_or_size_splits=2, axis=-1)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean
    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2, axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        sigma = math_ops.upper_bound(sigma, variance_upperbound ** 0.5)
    if not training:
        # need to handle images with non-standard sizes during compression; mu/sigma must have the
        # same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]

    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)

    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to have the same shape as input

    return locals()

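# Hedged sketch (not part of the repo) of how the dictionary returned by build_graph() above could be
# turned into a bits-back rate-distortion training objective. The `args.lmbda` weighting and the
# bits-per-pixel normalization mirror the compress() code further below; the repo's actual training
# script may differ in details.
def example_rd_objective(args, x):
    graph = build_graph(args, x, training=True)
    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[1:-1]), tf.float32)
    # net rate with bits back: -log p(y_tilde|z_tilde) - log p(z_tilde) + log q(z_tilde|y_tilde), in bits
    bits = (tf.reduce_sum(-tf.log(graph['y_likelihoods']), axis=[1, 2, 3]) +
            tf.reduce_sum(-tf.log(graph['z_likelihoods']), axis=[1, 2, 3]) +
            tf.reduce_sum(graph['log_q_z_tilde'], axis=[1, 2, 3])) / np.log(2)
    train_bpp = tf.reduce_mean(bits) / num_pixels
    # distortion on the [0, 255] scale, as in the evaluation code below
    train_mse = tf.reduce_mean(tf.squared_difference(x, graph['x_tilde'])) * (255 ** 2)
    return args.lmbda * train_mse + train_bpp
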
def main():
    wandb.init(project="vae-comparison")
    wandb.config.update(args)
    log_step = 0

    # set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # set device
    use_gpu = args.use_gpu and torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")
    print("training on {} device".format("cuda" if use_gpu else "cpu"))

    # load dataset
    train_loader, val_loader, test_loader = load_data(
        dataset=args.dataset,
        batch_size=args.batch_size,
        no_validation=args.no_validation,
        shuffle=args.shuffle,
        data_file=args.data_file)

    # define model or load checkpoint
    if args.train_from == '':
        print('--------------------------------')
        print("initializing new model")
        model = VAE(latent_dim=args.latent_dim)
    else:
        print('--------------------------------')
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']

    print('--------------------------------')
    print("model architecture")
    print(model)

    # set model for training
    model.to(device)
    model.train()

    # define optimizers and their schedulers
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_enc = torch.optim.Adam(model.enc.parameters(), lr=args.lr)
    optimizer_dec = torch.optim.Adam(model.dec.parameters(), lr=args.lr)
    lr_lambda = lambda count: 0.9
    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer, lr_lambda=lr_lambda)
    lr_scheduler_enc = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer_enc, lr_lambda=lr_lambda)
    lr_scheduler_dec = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer_dec, lr_lambda=lr_lambda)

    # set beta KL scaling parameter
    if args.warmup == 0:
        beta_ten = torch.tensor(1.)
    else:
        beta_ten = torch.tensor(0.1)

    # set savae meta optimizer
    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(utils.variational_loss,
                              model,
                              update_params,
                              beta=beta_ten,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=1,
                              max_grad_norm=args.svi_max_grad_norm)

    # if test flag set, evaluate and exit
    if args.test == 1:
        beta_ten.data.fill_(1.)
        eval(test_loader, model, meta_optimizer, device)
        importance_sampling(data=test_loader,
                            model=model,
                            batch_size=args.batch_size,
                            meta_optimizer=meta_optimizer,
                            device=device,
                            nr_samples=20000,
                            test_mode=True,
                            verbose=True,
                            mode=args.test_type)
        exit()

    # initialize counters and stats
    epoch = 0
    t = 0
    best_val_metric = 100000000
    best_epoch = 0
    loss_stats = []

    # training loop
    C = torch.tensor(0., device=device)
    C_local = torch.zeros(args.batch_size * len(train_loader), device=device)
    epsilon = None
    step = 0
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('--------------------------------')
        print('starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_cdiv = 0.
        train_nll = 0.
        train_acc_rate = 0.
        num_examples = 0
        count_one_pixels = 0
        for b, datum in enumerate(train_loader):
            t += 1
            if args.warmup > 0:
                beta_ten.data.fill_(
                    torch.min(torch.tensor(1.),
                              beta_ten + 1 / (args.warmup * len(train_loader))).data)
            img, _ = datum
            img = torch.where(img < 0.5, torch.zeros_like(img), torch.ones_like(img))
            if epoch == 1:
                count_one_pixels += torch.sum(img).item()
            img = img.to(device)

            optimizer.zero_grad()
            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()

            if args.model == 'svi':
                mean_svi = torch.zeros(args.batch_size, args.latent_dim,
                                       requires_grad=True, device=device)
                logvar_svi = torch.zeros(args.batch_size, args.latent_dim,
                                         requires_grad=True, device=device)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model.reparameterize(mean_svi_final.detach(),
                                                 logvar_svi_final.detach())
                preds = model.dec_forward(z_samples)
                nll_svi = utils.log_bernoulli_loss(preds, img)
                train_nll_svi += nll_svi.item() * args.batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * args.batch_size
                var_loss = nll_svi + beta_ten.item() * kl_svi
                var_loss.backward()
                if args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
            else:
                mean, logvar = model.enc_forward(img)
                z_samples = model.reparameterize(mean, logvar)
                preds = model.dec_forward(z_samples)
                nll_vae = utils.log_bernoulli_loss(preds, img)
                train_nll_vae += nll_vae.item() * args.batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.item() * args.batch_size

                if args.model == 'vae':
                    vae_loss = nll_vae + beta_ten.item() * kl_vae
                    vae_loss.backward()
                    optimizer.step()

                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = mean.clone().detach().requires_grad_(True)
                    logvar_svi = logvar.clone().detach().requires_grad_(True)
                    var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model.reparameterize(mean_svi_final, logvar_svi_final)
                    preds = model.dec_forward(z_samples)
                    nll_svi = utils.log_bernoulli_loss(preds, img)
                    train_nll_svi += nll_svi.item() * args.batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                    train_kl_svi += kl_svi.item() * args.batch_size
                    var_loss = nll_svi + beta_ten.item() * kl_svi
                    var_loss.backward(retain_graph=True)
                    var_param_grads = meta_optimizer.backward(
                        [mean_svi_final.grad, logvar_svi_final.grad])
                    var_param_grads = torch.cat(var_param_grads, 1)
                    var_params.backward(var_param_grads)
                    if args.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()

                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    pxz = utils.log_pxz(preds, img, z_samples)
                    first_term = torch.mean(pxz) + 0.5 * args.latent_dim
                    logqz = utils.log_normal_pdf(z_samples, mean, logvar)
                    if epoch == 7 and b == 0:
                        # switch to local variate control
                        C_local = torch.ones(args.batch_size * len(train_loader),
                                             device=device) * C
                    if args.model == "cdiv":
                        zt, samples, acc_rate, epsilon = hmc.hmc_vae(
                            z_samples.clone().detach().requires_grad_(),
                            model,
                            img,
                            epsilon=epsilon,
                            Burn=0,
                            T=args.num_hmc_iters,
                            adapt=0,
                            L=5)
                        train_acc_rate += torch.mean(acc_rate) * args.batch_size
                    else:
                        mean_all = torch.repeat_interleave(mean, args.num_svgd_particles, 0)
                        logvar_all = torch.repeat_interleave(logvar, args.num_svgd_particles, 0)
                        img_all = torch.repeat_interleave(img, args.num_svgd_particles, 0)
                        z_samples = mean_all + torch.randn(
                            args.num_svgd_particles * args.batch_size,
                            args.latent_dim,
                            device=device) * torch.exp(0.5 * logvar_all)
                        samples = svgd.svgd_batched(args.num_svgd_particles,
                                                    args.batch_size,
                                                    z_samples,
                                                    model,
                                                    img_all.view(-1, 784),
                                                    iter=args.num_svgd_iters)
                        # pick one random particle per batch element
                        z_ind = torch.randint(low=0, high=args.num_svgd_particles,
                                              size=(args.batch_size,), device=device) + \
                            torch.tensor(args.num_svgd_particles, device=device) * \
                            torch.arange(0, args.batch_size, device=device)
                        zt = samples[z_ind]

                    preds_zt = model.dec_forward(zt)
                    pxzt = utils.log_pxz(preds_zt, img, zt)
                    g_zt = pxzt + torch.sum(
                        0.5 * ((zt - mean) ** 2) * torch.exp(-logvar), 1)
                    second_term = torch.mean(g_zt)
                    cdiv = -first_term + second_term
                    train_cdiv += cdiv.item() * args.batch_size
                    train_nll += -torch.mean(pxzt).item() * args.batch_size

                    if epoch <= 6:
                        # global control variate C (exponential moving average)
                        loss = -first_term + torch.mean(
                            torch.sum(0.5 * ((zt - mean) ** 2) * torch.exp(-logvar), 1) +
                            (g_zt.detach() - C) * logqz)
                        if b == 0:
                            C = torch.mean(g_zt.detach())
                        else:
                            C = 0.9 * C + 0.1 * torch.mean(g_zt.detach())
                    else:
                        # per-datapoint (local) control variates
                        control = C_local[b * args.batch_size:(b + 1) * args.batch_size]
                        loss = -first_term + torch.mean(
                            torch.sum(0.5 * ((zt - mean) ** 2) * torch.exp(-logvar), 1) +
                            (g_zt.detach() - control) * logqz)
                        C_local[b * args.batch_size:(b + 1) * args.batch_size] = \
                            0.9 * C_local[b * args.batch_size:(b + 1) * args.batch_size] + \
                            0.1 * g_zt.detach()

                    loss.backward(retain_graph=True)
                    optimizer_enc.step()
                    optimizer_dec.zero_grad()
                    torch.mean(-utils.log_pxz(preds_zt, img, zt)).backward()
                    optimizer_dec.step()

            if t % 15000 == 0:
                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    lr_scheduler_enc.step()
                    lr_scheduler_dec.step()
                else:
                    lr_scheduler.step()

            num_examples += args.batch_size
            if b and (b + 1) % args.print_every == 0:
                step += 1
                print('--------------------------------')
                print('iteration: %d, epoch: %d, batch: %d/%d' %
                      (t, epoch, b + 1, len(train_loader)))
                if epoch > 1:
                    print('best epoch: %d: %.2f' % (best_epoch, best_val_metric))
                print('throughput: %.2f examples/sec' %
                      (num_examples / (time.time() - start_time)))
                if args.model != 'svi':
                    print('train_VAE_NLL: %.2f, train_VAE_KL: %.4f, train_VAE_NLLBnd: %.2f' %
                          (train_nll_vae / num_examples, train_kl_vae / num_examples,
                           (train_nll_vae + train_kl_vae) / num_examples))
                    wandb.log(
                        {
                            "train_vae_nll": train_nll_vae / num_examples,
                            "train_vae_kl": train_kl_vae / num_examples,
                            "train_vae_nll_bound":
                                (train_nll_vae + train_kl_vae) / num_examples,
                        },
                        step=log_step)
                if args.model == 'svi' or args.model == 'savae':
                    print('train_SVI_NLL: %.2f, train_SVI_KL: %.4f, train_SVI_NLLBnd: %.2f' %
                          (train_nll_svi / num_examples, train_kl_svi / num_examples,
                           (train_nll_svi + train_kl_svi) / num_examples))
                    wandb.log(
                        {
                            "train_svi_nll": train_nll_svi / num_examples,
                            "train_svi_kl": train_kl_svi / num_examples,
                            "train_svi_nll_bound":
                                (train_nll_svi + train_kl_svi) / num_examples,
                        },
                        step=log_step)
                if args.model == "cdiv" or args.model == "cdiv_svgd":
                    print('train_NLL: %.2f, train_CDIV: %.4f' %
                          (train_nll / num_examples, train_cdiv / num_examples))
                    wandb.log(
                        {
                            "train_nll": train_nll / num_examples,
                            "train_cdiv": train_cdiv / num_examples,
                        },
                        step=log_step)
                if args.model == "cdiv":
                    print('train_average_acc_rate: %.3f' %
                          (train_acc_rate / num_examples))
                    wandb.log(
                        {
                            "train_average_acc_rate": train_acc_rate / num_examples,
                        },
                        step=log_step)
                log_step += 1

        if epoch == 1:
            print('--------------------------------')
            print("count of pixels 1 in training data: {}".format(count_one_pixels))
            wandb.log({"dataset_pixel_check": count_one_pixels}, step=log_step)
        if args.no_validation:
            print('--------------------------------')
            print("[validation disabled!]")
        else:
            val_metric = eval(val_loader, model, meta_optimizer, device, epoch,
                              epsilon, log_step)

        checkpoint = {
            'args': args.__dict__,
            'model': model,
            'loss_stats': loss_stats
        }
        torch.save(checkpoint, args.checkpoint_path + "_last.pt")

        if not args.no_validation:
            loss_stats.append(val_metric)
            if val_metric < best_val_metric:
                best_val_metric = val_metric
                best_epoch = epoch
                print('saving checkpoint to %s' %
                      (args.checkpoint_path + "_best.pt"))
                torch.save(checkpoint, args.checkpoint_path + "_best.pt")

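# The loss helpers used in the training loop above (`utils.log_bernoulli_loss`, `utils.kl_loss_diag`)
# are not shown in this file. The sketches below give one standard way they could be implemented
# (per-example sums, batch means); the repo's actual reductions and shapes may differ.
def log_bernoulli_loss_sketch(preds, x):
    # negative Bernoulli log-likelihood; preds are probabilities in (0, 1)
    eps = 1e-7
    ll = x * torch.log(preds + eps) + (1. - x) * torch.log(1. - preds + eps)
    return -ll.view(ll.size(0), -1).sum(dim=1).mean()


def kl_loss_diag_sketch(mean, logvar):
    # KL( N(mean, diag(exp(logvar))) || N(0, I) ), averaged over the batch
    kl = -0.5 * torch.sum(1. + logvar - mean.pow(2) - logvar.exp(), dim=1)
    return kl.mean()
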
def importance_sampling(data, model, batch_size, meta_optimizer, device, nr_samples,
                        test_mode, verbose=False, mode="vae", adapt=False,
                        epsilon=None, log_step=0):
    """
    Computes an importance sampling estimate of the data log-likelihood.

    :param data: datapoints
    :param model: VAE model
    :param batch_size: batch size
    :param meta_optimizer: savae meta optimizer
    :param device: torch device
    :param nr_samples: number of samples for importance sampling
    :param test_mode: true if run on test data
    :param verbose: true if the current estimate should be printed
    :param mode: mode of evaluation - way of generating the sampling (proposal) distribution
    :param adapt: adapt parameter for hmc
    :param epsilon: epsilon parameter for hmc
    :param log_step: step for metrics logging
    :return: importance sampling estimate of the data log-likelihood
    """
    model.eval()
    results = torch.zeros(batch_size * len(data))
    disp_const = 2 * math.log(1.2)  # inflate proposal variance by a factor of 1.2^2
    t = 0
    S = nr_samples
    for datum in data:
        img_batch, _ = datum
        img_batch = img_batch.to(device)
        for img_single in img_batch:
            img_single = torch.where(img_single < 0.5,
                                     torch.zeros_like(img_single),
                                     torch.ones_like(img_single))
            img = img_single.repeat(S, 1, 1)
            if mode == 'svi':
                mean_svi = 0.1 * torch.zeros(batch_size, model.latent_dim,
                                             device=device, requires_grad=True)
                logvar_svi = 0.1 * torch.zeros(batch_size, model.latent_dim,
                                               device=device, requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model.reparameterize(mean_svi_final.detach(),
                                                 logvar_svi_final.detach())
                preds = model.dec_forward(z_samples)
                mean, logvar = mean_svi_final, logvar_svi_final
            elif mode == 'cdiv' or mode == "cdiv_svgd":
                mean, logvar = model.enc_forward(img_single.unsqueeze(0))
                z_samples = model.reparameterize(mean, disp_const + logvar)
                if mode == "cdiv":
                    z, samples, acc_rate, _ = hmc.hmc_vae(z_samples,
                                                          model,
                                                          img_single.unsqueeze(0),
                                                          epsilon=epsilon,
                                                          Burn=0,
                                                          T=300,
                                                          adapt=adapt,
                                                          L=5)
                    mean = torch.mean(samples, 2)
                    var = torch.var(samples, 2)
                else:
                    ind = 0
                    z_samples = mean[ind].unsqueeze(0) + \
                        torch.randn(300, mean.size(1), device=device) * \
                        torch.exp(0.5 * logvar[ind]).unsqueeze(0)
                    samples = svgd.svgd(z_samples, model, img_single.view(-1, 784), iter=20)
                    mean = torch.mean(samples, 0)
                    var = torch.var(samples, 0)
                eps = 0.00000001
                var = torch.where(var < eps, torch.ones_like(var) * eps, var)
                logvar = torch.log(var)
                z_samples = model.reparameterize(mean.repeat(S, 1),
                                                 disp_const + logvar.repeat(S, 1))
                preds = model.dec_forward(z_samples)
            else:
                mean, logvar = model.enc_forward(img)
                z_samples = model.reparameterize(mean, disp_const + logvar)
                preds = model.dec_forward(z_samples)
                if mode == 'savae':
                    mean_svi = mean.data.clone().detach().requires_grad_(True)
                    logvar_svi = logvar.clone().detach().requires_grad_(True)
                    var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model.reparameterize(mean_svi_final,
                                                     disp_const + logvar_svi_final)
                    preds = model.dec_forward(z_samples.detach())
                    mean, logvar = mean_svi_final, logvar_svi_final

            log_pxz = utils.log_pxz(preds, img, z_samples)
            logqz = utils.log_normal_pdf(z_samples, mean, disp_const + logvar)
            results[t] = (torch.logsumexp(log_pxz - logqz, 0) - math.log(S)).detach()
            current_mean = torch.sum(results) / (t + 1)
            if verbose and t % 100 == 0 and t:
                print("---> IS estimate after {} examples".format(t))
                print(current_mean.item())
                if test_mode and t > 100:
                    wandb.log({"IS_mean": current_mean})
            t += 1
wandb.log({"test_importance_sampling_nll": current_mean}) print('--------------------------------') print("IS {} samples per datapoint".format(S)) print("test IS estimate: {}".format(current_mean.item())) else: wandb.log({"val_importance_sampling_nll": current_mean}, step=log_step) print('--------------------------------') print("IS {} samples per datapoint".format(S)) print("val IS estimate: {}".format(current_mean.item())) model.train() return current_mean.item()
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format."""
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shape, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that
    # depends on x. Therefore if multiple ops need to be evaluated on the same batch of data, they
    # have to be grouped like sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder('float32', (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 * args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 * args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)
    y_tilde = tf.round(y)
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to have the same shape as input

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2, axis=-1)
    z_mean = tf.placeholder('float32', z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde), num_or_size_splits=2, axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]

    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels:
    # [-log p(\tilde y | \tilde z) - log p(\tilde z) + log q(\tilde z | \tilde y)] / num_pixels
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1, 2, 3]
    bpp_back = tf.reduce_sum(-log_q_z_tilde, axis=axes_except_batch) / (np.log(2) * img_num_pixels)
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    local_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde), axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)

        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: [] for key in eval_fields}  # append across all batches

        batch_idx = 0
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}

                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init], feed_dict=x_feed_dict)  # np arrays
                opt_obj_hist = []
                opt_grad_hist = []
                lr = 0.005
                local_its = 1000
                from adam import Adam
                np_adam_optimizer = Adam(lr=lr)
                for it in range(local_its):
                    grads, obj = sess.run([local_gradients, train_bpp],
                                          feed_dict={
                                              z_mean: z_mean_cur,
                                              z_logvar: z_logvar_cur,
                                              **x_feed_dict
                                          })
                    z_mean_cur, z_logvar_cur = np_adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % 100 == 0:
                        print('negative local ELBO', obj)
                        opt_obj_hist.append(obj)
                        opt_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1
            except tf.errors.OutOfRangeError:
                break

    for field in eval_fields:
        all_results_arrs[field] = np.asarray(all_results_arrs[field])

    input_file = os.path.basename(args.input_file)
    results_dict = all_results_arrs
    trained_script_name = args.runname.split('-')[0]
    script_name = os.path.splitext(os.path.basename(__file__))[0]  # current script name, without extension
    save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
    if script_name != trained_script_name:
        save_file = 'rd-%s+%s-input=%s.npz' % (script_name, args.runname, input_file)
    np.savez(os.path.join(args.results_dir, save_file), **results_dict)

    for field in eval_fields:
        arr = all_results_arrs[field]
        print('Avg {}: {:0.4f}'.format(field, arr.mean()))

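# `adam.Adam` above is a small NumPy optimizer exposing update(params, grads) -> new_params, as used
# in the local optimization loop. A minimal sketch of such an optimizer follows; the hyperparameter
# defaults and internals are assumptions, not the repo's actual implementation.
class AdamSketch(object):
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.lr, self.beta1, self.beta2, self.epsilon = lr, beta1, beta2, epsilon
        self.m = None  # first-moment estimates, one per parameter array
        self.v = None  # second-moment estimates
        self.t = 0     # step counter, used for bias correction

    def update(self, params, grads):
        if self.m is None:
            self.m = [np.zeros_like(p) for p in params]
            self.v = [np.zeros_like(p) for p in params]
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g ** 2
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            new_params.append(p - self.lr * m_hat / (np.sqrt(v_hat) + self.epsilon))
        return new_params
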
def build_graph(args, x, training=True):
    """
    Build the computational graph of the model. x should be a float tensor of shape [batch, H, W, 3].
    Given original image x, the model computes a lossy reconstruction x_tilde and various other
    quantities of interest. During training we sample from box-shaped posteriors; during compression
    this is approximated by rounding.
    """
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 * args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 * args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)

    # z_tilde ~ q(z_tilde | x) = q(z_tilde | h_a(y))
    z_mean, z_logvar = tf.split(hyper_analysis_transform(y), num_or_size_splits=2, axis=-1)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean
    from utils import log_normal_pdf
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(z_tilde.shape[-1], dims=(3, 3, 3))
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde), num_or_size_splits=2, axis=-1)
    sigma = tf.exp(sigma)  # make positive
    if training:
        sigma = math_ops.upper_bound(sigma, variance_upperbound ** 0.5)
    if not training:
        # need to handle images with non-standard sizes during compression; mu/sigma must have the
        # same shape as y
        y_shape = tf.shape(y)
        mu = mu[:, :y_shape[1], :y_shape[2], :]
        sigma = sigma[:, :y_shape[1], :y_shape[2], :]

    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mu)
    # sample y_tilde from q(y_tilde|x) = U(y-0.5, y+0.5) = U(g_a(x)-0.5, g_a(x)+0.5), and then compute
    # the pdf of y_tilde under the conditional prior/entropy model
    # p(y_tilde|z_tilde) = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5).
    # Note that at test/compression time, the resulting y_tilde doesn't simply equal round(y);
    # instead, the conditional_bottleneck does something smarter and slightly more optimal:
    # y_hat = floor(y + 0.5 - prior_mean), so that the mean (mu) of the prior coincides with one of
    # the quantization bins.
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=training)

    x_tilde = synthesis_transform(y_tilde)
    if not training:
        x_shape = tf.shape(x)
        x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to have the same shape as input

    return locals()

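# Hedged illustration of the mean-centered quantization described in the comment above: round the
# residual y - mu so that a quantization bin center falls exactly on the prior mean. Whether the
# offset mu is added back to the quantized value depends on the tfc version and mode, so treat this
# as a sketch of the idea rather than the library's API.
def quantize_towards_prior_mean_sketch(y, mu):
    return tf.round(y - mu) + mu  # equals floor(y + 0.5 - mu) + mu away from rounding ties
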
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format."""
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shape, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    num_images = int(X.shape[0])
    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that
    # depends on x. Therefore if multiple ops need to be evaluated on the same batch of data, they
    # have to be grouped like sess.run([op1, op2, ...]).
    # x = dataset.make_one_shot_iterator().get_next()
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder('float32', (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    from utils import log_normal_pdf
    from learned_prior import BMSHJ2018Prior
    hyper_prior = BMSHJ2018Prior(args.num_filters, dims=(3, 3, 3))

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      num_output_filters=2 * args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 * args.num_filters)
    # entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial optimization (where we still have access to x).
    # Soft-to-hard rounding with Gumbel-softmax trick; for each element of z_tilde, let R be a 2D
    # auxiliary one-hot random vector, such that R=[1, 0] means rounding DOWN and [0, 1] means
    # rounding UP.
    # Let the logits of each outcome be -(z - z_floor) / T and -(z_ceil - z) / T (i.e., a Boltzmann
    # distribution with energies (z - floor(z)) and (ceil(z) - z)), so p(R==[1,0]) is proportional to
    # exp(-(z - z_floor) / T), etc.
    # Let z_tilde = p(R==[1,0]) * floor(z) + p(R==[0,1]) * ceil(z), so z_tilde -> round(z) as T -> 0.
    import tensorflow_probability as tfp
    T = tf.placeholder('float32', shape=[], name='temperature')
    y_init = analysis_transform(x)
    y = tf.placeholder('float32', y_init.shape)
    y_floor = tf.floor(y)
    y_ceil = tf.ceil(y)
    y_bds = tf.stack([y_floor, y_ceil], axis=-1)
    epsilon = 1e-5
    logits = tf.stack(
        [
            -tf.math.atanh(
                tf.clip_by_value(y - y_floor, -1 + epsilon, 1 - epsilon)) / T,
            -tf.math.atanh(
                tf.clip_by_value(y_ceil - y, -1 + epsilon, 1 - epsilon)) / T
        ],
        axis=-1
    )  # last dim are logits for DOWN or UP; clip to prevent NaN as temperature -> 0
    rounding_dist = tfp.distributions.RelaxedOneHotCategorical(
        T, logits=logits)  # technically we can use a different temperature here
    sample_concrete = rounding_dist.sample()
    y_tilde = tf.reduce_sum(y_bds * sample_concrete, axis=-1)  # inner product in last dim
    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[2], :]  # crop reconstruction to have the same shape as input

    # z_tilde ~ q(z_tilde | h_a(\tilde y))
    z_mean_init, z_logvar_init = tf.split(hyper_analysis_transform(y_tilde),
                                          num_or_size_splits=2, axis=-1)
    z_mean = tf.placeholder('float32', z_mean_init.shape)  # initialize to inference network results
    z_logvar = tf.placeholder('float32', z_logvar_init.shape)
    eps = tf.random.normal(shape=tf.shape(z_mean))
    z_tilde = eps * tf.exp(z_logvar * .5) + z_mean
    log_q_z_tilde = log_normal_pdf(z_tilde, z_mean, z_logvar)  # bits back

    # compute the pdf of z_tilde under the flexible (hyper)prior p(z_tilde) ("z_likelihoods")
    z_likelihoods = hyper_prior.pdf(z_tilde, stop_gradient=False)
    z_likelihoods = math_ops.lower_bound(z_likelihoods, likelihood_lowerbound)

    # compute parameters of p(y_tilde|z_tilde)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde), num_or_size_splits=2, axis=-1)
    sigma = tf.exp(sigma)  # make positive

    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]

    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table, mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    #### END build compression graph ####

    # Total number of bits divided by number of pixels:
    # [-log p(\tilde y | \tilde z) - log p(\tilde z) + log q(\tilde z | \tilde y)] / num_pixels
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1, 2, 3]
    batch_log_q_z_tilde = tf.reduce_sum(log_q_z_tilde, axis=axes_except_batch)
    bpp_back = -batch_log_q_z_tilde / (np.log(2) * img_num_pixels)
    batch_log_cond_p_y_tilde = tf.reduce_sum(tf.log(y_likelihoods), axis=axes_except_batch)
    y_bpp = -batch_log_cond_p_y_tilde / (np.log(2) * img_num_pixels)
    batch_log_p_z_tilde = tf.reduce_sum(tf.log(z_likelihoods), axis=axes_except_batch)
    z_bpp = -batch_log_p_z_tilde / (np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp - bpp_back  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    # float_train_mse = train_mse
    # psnr = - 10 * (tf.log(float_train_mse) / np.log(10))  # float MSE computed on float images
    train_mse *= 255 ** 2

    # The rate-distortion cost.
    if args.lmbda < 0:
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')[0])  # re-use the lmbda as used for training
        print('Defaulting lmbda (mse coefficient) to %g as used in model training.' % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z_mean, z_logvar])
    r_gradients = tf.gradients(train_bpp, [z_mean, z_logvar])

    # Bring both images back to 0..255 range, for evaluation only.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde), axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)

        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp', 'est_bpp_back'
        ]
        eval_tensors = [
            mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp, bpp_back
        ]
        all_results_arrs = {key: [] for key in eval_fields}  # append across all batches

        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        log_itv = 100
        rd_lr = 0.005
        # rd_opt_its = args.sga_its
        rd_opt_its = 2000
        annealing_scheme = 'exp0'
        annealing_rate = args.annealing_rate  # default annealing_rate = 1e-3
        t0 = args.t0  # default t0 = 700
        T_ub = 0.5  # max/initial temperature
        from utils import annealed_temperature
        r_lr = 0.003
        r_opt_its = 2000
        from adam import Adam

        batch_idx = 0
        while True:
            try:
                x_val = sess.run(x_next)
                x_feed_dict = {x_ph: x_val}

                # 1. Perform R-D optimization conditioned on ground truth x
                print('----RD Optimization----')
                y_cur = sess.run(y_init, feed_dict=x_feed_dict)  # np arrays
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init], feed_dict={y_tilde: y_cur})
                rd_loss_hist = []
                adam_optimizer = Adam(lr=rd_lr)
                opt_record = {
                    'its': [],
                    'T': [],
                    'rd_loss': [],
                    'rd_loss_after_rounding': []
                }
                for it in range(rd_opt_its):
                    temperature = annealed_temperature(it, r=annealing_rate, ub=T_ub,
                                                       scheme=annealing_scheme, t0=t0)
                    grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                        [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                        feed_dict={
                            y: y_cur,
                            z_mean: z_mean_cur,
                            z_logvar: z_logvar_cur,
                            **x_feed_dict,
                            T: temperature
                        })
                    y_cur, z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [y_cur, z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == rd_opt_its:
                        psnr_ = psnr_.mean()
                        if args.verbose:
                            bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                                [train_bpp, psnr, rd_loss],
                                feed_dict={
                                    y_tilde: np.round(y_cur),
                                    z_mean: z_mean_cur,
                                    z_logvar: z_logvar_cur,
                                    **x_feed_dict
                                })
                            psnr_after_rounding = psnr_after_rounding.mean()
                            print('it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f' %
                                  (it, temperature, obj, mse_, train_bpp_, psnr_,
                                   rd_loss_after_rounding, bpp_after_rounding, psnr_after_rounding))
                        else:
                            print('it=%d, T=%.3f rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f' %
                                  (it, temperature, obj, mse_, train_bpp_, psnr_))
                        rd_loss_hist.append(obj)
                print()

                # 2. Fix y_tilde, perform rate optimization w.r.t. z_mean and z_logvar.
                y_tilde_cur = np.round(y_cur)  # this is the latents we end up transmitting
                # rate_feed_dict = {y_tilde: y_tilde_cur, **x_feed_dict}
                rate_feed_dict = {y_tilde: y_tilde_cur}
                np.random.seed(seed)
                tf.set_random_seed(seed)
                print('----Rate Optimization----')

                # Reinitialize based on the value of y_tilde
                z_mean_cur, z_logvar_cur = sess.run(
                    [z_mean_init, z_logvar_init], feed_dict=rate_feed_dict)  # np arrays
                r_loss_hist = []
                # rate_grad_hist = []
                adam_optimizer = Adam(lr=r_lr)
                for it in range(r_opt_its):
                    grads, obj = sess.run([r_gradients, train_bpp],
                                          feed_dict={
                                              z_mean: z_mean_cur,
                                              z_logvar: z_logvar_cur,
                                              **rate_feed_dict
                                          })
                    z_mean_cur, z_logvar_cur = adam_optimizer.update(
                        [z_mean_cur, z_logvar_cur], grads)
                    if it % log_itv == 0 or it + 1 == r_opt_its:
                        print('it=', it, '\trate=', obj)
                        r_loss_hist.append(obj)
                        # rate_grad_hist.append(np.mean(np.abs(grads)))
                print()

                # fig, axes = plt.subplots(nrows=2, sharex=True)
                # axes[0].plot(rd_loss_hist)
                # axes[0].set_ylabel('RD loss')
                # axes[1].plot(r_loss_hist)
                # axes[1].set_ylabel('Rate loss')
                # axes[1].set_xlabel('SGD iterations')
                # plt.savefig('plots/local_q_opt_hist-%s-input=%s-b=%d.png' %
                #             (args.runname, os.path.basename(args.input_file), batch_idx))

                # If requested, transform the quantized image back and measure performance.
                eval_arrs = sess.run(eval_tensors,
                                     feed_dict={
                                         y_tilde: y_tilde_cur,
                                         z_mean: z_mean_cur,
                                         z_logvar: z_logvar_cur,
                                         **x_feed_dict
                                     })
                for field, arr in zip(eval_fields, eval_arrs):
                    all_results_arrs[field] += arr.tolist()

                batch_idx += 1
            except tf.errors.OutOfRangeError:
                break

    for field in eval_fields:
        all_results_arrs[field] = np.asarray(all_results_arrs[field])

    input_file = os.path.basename(args.input_file)
    results_dict = all_results_arrs
    trained_script_name = args.runname.split('-')[0]
    script_name = os.path.splitext(os.path.basename(__file__))[0]  # current script name, without extension
    save_file = 'rd-%s-input=%s.npz' % (args.runname, input_file)
    if script_name != trained_script_name:
        save_file = 'rd-%s-lmbda=%g+%s-input=%s.npz' % (
            script_name, args.lmbda, args.runname, input_file)
    np.savez(os.path.join(args.results_dir, save_file), **results_dict)

    for field in eval_fields:
        arr = all_results_arrs[field]
        print('Avg {}: {:0.4f}'.format(field, arr.mean()))

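# NumPy toy sketch (not from the repo) of the soft-to-hard rounding used in the compression graph
# above: deterministic softmax weights stand in for the sampled Gumbel-softmax/concrete relaxation,
# and the atanh sharpening of the logits is omitted. It only illustrates that the interpolation
# between floor and ceil collapses to hard rounding as the temperature T -> 0.
def soft_round_sketch(y, temperature):
    y_floor, y_ceil = np.floor(y), np.ceil(y)
    logits = np.stack([-(y - y_floor) / temperature, -(y_ceil - y) / temperature], axis=-1)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs[..., 0] * y_floor + probs[..., 1] * y_ceil

# soft_round_sketch(np.array([0.2, 0.7]), 1.0)   -> soft values between floor and ceil
# soft_round_sketch(np.array([0.2, 0.7]), 0.01)  -> approximately [0., 1.]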