import math
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable

# Project-local helpers (gauss_reparametrize, kld_std_guss, log_prob_laplace,
# visualize_data) are assumed to be importable from elsewhere in this repo.


def compute_train_kld(train_dataset, model):
    """DEBUGGING KLD: average the encoder KL over the first 101 training
    batches as a quick sanity check."""
    train_kl = []
    for i, batch in enumerate(train_dataset):
        real_A, real_B = Variable(batch['A']), Variable(batch['B'])
        real_A = real_A.cuda()
        real_B = real_B.cuda()
        fake_A = model.predict_A(real_B)
        params = model.predict_enc_params(fake_A, real_B)
        mu = params[0]
        # 0.0 * mu is a zero log-variance, i.e. KL( N(mu, I) || N(0, I) ).
        train_kl.append(kld_std_guss(mu, 0.0 * mu).mean(0).data[0])
        if i == 100:
            break
    print('train KL: %.4f' % np.mean(train_kl))
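# `kld_std_guss` and `gauss_reparametrize` are project-local helpers used in
# this file but defined elsewhere. A minimal sketch of the standard formulas
# they presumably implement, assuming `logvar` parametrizes log(sigma^2) and
# that the sample comes back as (N, nlatent, 1, 1) (callers below reshape it
# with .view(size[0], nlatent)); the underscored names are illustrative
# stand-ins, not the repo's actual code:


def _kld_std_guss_sketch(mu, logvar):
    # Analytic KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over the
    # latent dimensions: 0.5 * sum(mu^2 + exp(logvar) - logvar - 1).
    return 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(1)


def _gauss_reparametrize_sketch(mu, logvar):
    # Differentiable sample z = mu + sigma * eps, eps ~ N(0, I).
    eps = Variable(torch.randn(mu.size()).cuda())
    z = mu + (0.5 * logvar).exp() * eps
    return z.view(z.size(0), z.size(1), 1, 1)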
def train_logvar(dataset, model, epochs=1, use_gpu=True):
    """Fit a single per-pixel Laplace log-variance for the observation model
    p(B | A, z) by minimizing the variational upper bound (UBO) on the
    negative log-likelihood of B."""
    # One global noise parameter shared across the batch (GPU assumed here,
    # as in the rest of this file, regardless of `use_gpu`).
    logvar_B = Variable(torch.zeros(1, 3, 64, 64).fill_(math.log(0.01)).cuda(),
                        requires_grad=True)
    iterative_opt = torch.optim.RMSprop([logvar_B], lr=1e-2)
    for _ in range(epochs):
        for batch in dataset:
            real_B = Variable(batch['B'])
            if use_gpu:
                real_B = real_B.cuda()
            size = real_B.size()
            # Uniform dequantization noise: pixels live on a grid of width
            # 1/127.5 once scaled to [-1, 1].
            dequant = Variable(
                torch.zeros(*real_B.size()).uniform_(0, 1. / 127.5).cuda())
            real_B = real_B + dequant
            # Default posterior q(z | B): N(0, 0.01 * I), overwritten below
            # when the model has an encoder.
            enc_mu = Variable(torch.zeros(size[0], model.opt.nlatent).cuda())
            enc_logvar = Variable(
                torch.zeros(size[0], model.opt.nlatent)
                .fill_(math.log(0.01)).cuda())
            fake_A = model.predict_A(real_B)
            if hasattr(model, 'netE_B'):
                params = model.predict_enc_params(fake_A, real_B)
                enc_mu = Variable(params[0].data)
                if len(params) == 2:
                    enc_logvar = Variable(params[1].data)
            z_B = gauss_reparametrize(enc_mu, enc_logvar)
            fake_B = model.predict_B(fake_A, z_B)
            z_B = z_B.view(size[0], model.opt.nlatent)
            log_prob = log_prob_laplace(real_B, fake_B, logvar_B)
            log_prob = log_prob.view(size[0], -1).sum(1)
            kld = kld_std_guss(enc_mu, enc_logvar)
            # The (64*64*3)*log(127.5) term converts the density on the
            # [-1, 1]-scaled, dequantized pixels back to a likelihood on the
            # original discrete 0-255 values (log bin width, per dimension).
            ubo = (-log_prob + kld) + (64 * 64 * 3) * math.log(127.5)
            ubo_val = ubo.mean(0).data[0]
            kld_val = kld.mean(0).data[0]
            # Nats per image -> bits per dimension.
            bpp = ubo_val / (64 * 64 * 3 * math.log(2.))
            print('UBO: %.4f, KLD: %.4f, BPP: %.4f' % (ubo_val, kld_val, bpp))
            loss = ubo.mean(0)
            iterative_opt.zero_grad()
            loss.backward()
            iterative_opt.step()
    return logvar_B
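# `log_prob_laplace` is likewise assumed rather than defined here. A sketch
# under the assumption that `logvar` holds 2*log(b) for a Laplace scale b
# (so b = exp(0.5 * logvar)), broadcast against the batch:


def _log_prob_laplace_sketch(x, mu, logvar):
    # Elementwise Laplace log-density:
    # log p(x) = -|x - mu| / b - log(2b) = -|x - mu| / b - 0.5*logvar - log 2.
    scale = (0.5 * logvar).exp()
    return -(x - mu).abs() / scale - 0.5 * logvar - math.log(2.)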
def variational_ubo(model, real_A, real_B, steps, visualize=False,
                    vis_name=None, vis_path=None, verbose=False,
                    logvar_B=None, use_gpu=True, vis_batch=25,
                    compute_l1=False):
    """Tighten the variational upper bound (UBO) on -log p(B | A) by
    optimizing per-example posterior parameters (mu, logvar) for `steps`
    RMSprop iterations, starting from the encoder's prediction when one is
    available. Returns the final (UBO, KLD, bits-per-dim) values."""
    if visualize:
        assert vis_name is not None and vis_path is not None
    # Uniform dequantization noise (GPU assumed throughout, as elsewhere in
    # this file, regardless of `use_gpu`).
    dequant = Variable(
        torch.zeros(*real_B.size()).uniform_(0, 1. / 127.5).cuda())
    size = real_A.size()
    vis_size = real_A[:vis_batch].size()
    # Define q(z | A, B) parameters: N(0, 0.01 * I) by default.
    mu = Variable(torch.zeros(size[0], model.opt.nlatent).cuda(),
                  requires_grad=True)
    logvar = Variable(torch.zeros(size[0], model.opt.nlatent).fill_(
        math.log(0.01)).cuda(), requires_grad=True)
    if logvar_B is None:
        logvar_B = Variable(
            torch.zeros(1, 3, 64, 64).fill_(math.log(0.01)).cuda())
    # Initialize (mu, logvar) from the encoder when the model has one.
    if hasattr(model, 'netE_B'):
        params = model.predict_enc_params(real_A, real_B)
        mu = Variable(params[0].data, requires_grad=True)
        if len(params) == 2:
            logvar = Variable(params[1].data, requires_grad=True)
    iterative_opt = torch.optim.RMSprop([mu, logvar], lr=1e-2)
    real_B = real_B + dequant
    rA = Variable(real_A.data, volatile=True)
    z_B = gauss_reparametrize(mu, logvar)
    fake_B = model.predict_B(real_A, z_B)
    if compute_l1:
        if model.opt.stoch_enc:
            rec_B = fake_B
        else:
            x = Variable(mu.view(mu.size(0), mu.size(1), 1, 1).data,
                         volatile=True)
            rec_B = model.predict_B(rA, x)
    if visualize:
        if model.opt.stoch_enc:
            vis_z_B = z_B[:vis_batch]
        else:
            vis_z_B = mu.view(mu.size(0), mu.size(1), 1, 1)[:vis_batch]
        vis_B = model.predict_B(real_A[:vis_batch], vis_z_B)
        save_path = os.path.join(vis_path, '%s_0.png' % vis_name)
        visualize_data(
            model.opt,
            [real_A.data[:vis_batch], real_B.data[:vis_batch], vis_B.data],
            vis_size, save_path)
    # Defined here so the function degrades gracefully when steps == 0.
    ubo_val = kld_val = bpp = None
    for i in range(steps):
        z_B = z_B.view(size[0], model.opt.nlatent)
        log_prob = log_prob_laplace(real_B, fake_B, logvar_B)
        log_prob = log_prob.view(size[0], -1).sum(1)
        kld = kld_std_guss(mu, logvar)
        # See train_logvar above for the dequantization constant.
        ubo = (-log_prob + kld) + (64 * 64 * 3) * math.log(127.5)
        ubo_val = ubo.mean(0).data[0]
        kld_val = kld.mean(0).data[0]
        bpp = ubo_val / (64 * 64 * 3 * math.log(2.))
        if compute_l1:
            l1_loss = F.l1_loss(real_B, rec_B).mean(0).data[0]
        if verbose:
            res_str = '[%d] UBO: %.4f, KLD: %.4f, BPP: %.4f' % (
                i, ubo_val, kld_val, bpp)
            if compute_l1:
                res_str = '%s, L1: %.4f' % (res_str, l1_loss)
            print(res_str)
        loss = ubo.mean(0)
        iterative_opt.zero_grad()
        loss.backward()
        iterative_opt.step()
        # Re-sample and re-decode with the updated posterior parameters.
        z_B = gauss_reparametrize(mu, logvar)
        fake_B = model.predict_B(real_A, z_B)
        if compute_l1:
            if model.opt.stoch_enc:
                rec_B = fake_B
            else:
                x = Variable(mu.view(mu.size(0), mu.size(1), 1, 1).data,
                             volatile=True)
                rec_B = model.predict_B(rA, x)
        if visualize and (i + 1) % 100 == 0:
            if model.opt.stoch_enc:
                vis_z_B = z_B[:vis_batch]
            else:
                vis_z_B = mu.view(mu.size(0), mu.size(1), 1, 1)[:vis_batch]
            vis_B = model.predict_B(real_A[:vis_batch], vis_z_B)
            save_path = os.path.join(vis_path, '%s_%d.png' % (vis_name, i + 1))
            visualize_data(
                model.opt,
                [real_A.data[:vis_batch], real_B.data[:vis_batch], vis_B.data],
                vis_size, save_path)
    return ubo_val, kld_val, bpp
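# A hedged usage sketch tying the three routines above together at evaluation
# time. `train_dataset`, `val_dataset`, and `model` are placeholders for
# whatever loaders and model object the surrounding scripts construct, and
# `steps=200` is an arbitrary illustrative choice, not a value taken from
# this repo:


def _evaluate_ubo_sketch(train_dataset, val_dataset, model):
    compute_train_kld(train_dataset, model)        # sanity-check encoder KL
    logvar_B = train_logvar(train_dataset, model)  # fit observation noise
    for batch in val_dataset:
        real_A = Variable(batch['A']).cuda()
        real_B = Variable(batch['B']).cuda()
        ubo, kld, bpp = variational_ubo(model, real_A, real_B, steps=200,
                                        verbose=True, logvar_B=logvar_B)
        print('val UBO: %.4f, KLD: %.4f, BPP: %.4f' % (ubo, kld, bpp))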