def calculate_lower_bound(self, X_full):
    '''
    Accumulate ``-RE + KL`` over ``X_full`` in mini-batches and divide by
    the number of rows in ``X_full``.

    :param X_full: data tensor; each mini-batch slice is reshaped to
        ``(-1, prod(self.args.input_size))`` before the forward pass
    :return: accumulated ``-RE + KL`` divided by ``X_full.size(0)``
    '''
    MB = 100
    n_samples = X_full.size(0)

    # Ceil integer division. The original ``size / MB`` is a float on
    # Python 3 (``range`` rejects it), and flooring would silently drop the
    # last partial batch although the sum is normalised by the full N.
    n_batches = (n_samples + MB - 1) // MB

    lower_bound = 0.
    for i in range(n_batches):
        x = X_full[i * MB: (i + 1) * MB].view(-1, np.prod(self.args.input_size))

        # pass through VAE
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

        # RE (called without ``dim``; presumably reduced over the whole
        # batch by log_Bernoulli - TODO confirm)
        RE = log_Bernoulli(x, x_mean)

        # KL, summed over the batch
        log_p_z = self.log_p_z(z_q)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -torch.sum(log_p_z - log_q_z)

        # NOTE(review): ``.data[0]`` relies on one-element tensor indexing;
        # confirm against the PyTorch version this project targets.
        lower_bound += (-RE + KL).cpu().data[0]

    lower_bound = lower_bound / n_samples

    return lower_bound
def calculate_loss(self, x, beta=1., average=False):
    '''
    Compute ``-RE + beta * KL`` for a batch.

    :param x: input image(s)
    :param beta: weight of the KL term (a hyperparam for warmup)
    :param average: when True, ``loss``, ``RE`` and ``KL`` are each reduced
        with ``torch.mean``
    :return: tuple ``(loss, RE, KL)``
    :raises Exception: if ``self.args.input_type`` is not a supported type
    '''
    # pass through VAE
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

    # RE: the data term depends on the configured input type
    input_type = self.args.input_type
    if input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif input_type == 'multinomial':
        # Not a reconstruction error, strictly speaking, but a log-likelihood.
        RE = log_Softmax(x, x_mean, dim=1)
    elif input_type in ('gray', 'continuous'):
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL: difference between the log-densities under q and under the prior
    log_p_z = self.log_p_z(z_q)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    log_ratio = log_p_z - log_q_z
    KL = -log_ratio

    loss = -RE + beta * KL

    if not average:
        return loss, RE, KL

    return torch.mean(loss), torch.mean(RE), torch.mean(KL)
def calculate_loss(self, x, beta=1., average=False):
    '''
    Compute ``-RE + beta * KL`` for a batch.

    :param x: input image(s)
    :param beta: weight of the KL term (a hyperparam for warmup)
    :param average: when True, ``loss``, ``RE`` and ``KL`` are each reduced
        with ``torch.mean``
    :return: tuple ``(loss, RE, KL)``
    :raises Exception: if ``self.args.input_type`` is not a supported type
    '''
    # pass through VAE
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

    # RE
    kind = self.args.input_type
    if kind == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif kind in ('gray', 'continuous'):
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    prior_logp = self.log_p_z(z_q)
    posterior_logp = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(prior_logp - posterior_logp)

    loss = -RE + beta * KL

    if average:
        loss, RE, KL = torch.mean(loss), torch.mean(RE), torch.mean(KL)

    return loss, RE, KL
def calculate_loss(self, x, beta=1., average=False):
    '''
    Compute ``-RE + beta * KL`` for the two-level latent model.

    :param x: input image(s); flattened to ``(x.size(0), -1)`` for the
        'gray' / 'continuous' input types
    :param beta: weight of the KL term
    :param average: when True, ``loss``, ``RE`` and ``KL`` are each reduced
        with ``torch.mean``
    :return: tuple ``(loss, RE, KL)``
    :raises Exception: if ``self.args.input_type`` is not a supported type
    '''
    # pass through VAE
    (x_mean, x_logvar,
     z1_q, z1_q_mean, z1_q_logvar,
     z2_q, z2_q_mean, z2_q_logvar,
     z1_p_mean, z1_p_logvar) = self.forward(x)

    # RE
    input_type = self.args.input_type
    if input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif input_type in ('gray', 'continuous'):
        # flatten everything but the leading dimension before scoring
        x = x.view(x.size(0), -1)
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL: z1 is scored under p(z1 | z2) and q(z1 | .), z2 under the prior
    # and q(z2 | .)
    log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
    log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
    log_p_z2 = self.log_p_z2(z2_q)
    log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
    log_gap = log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2
    KL = -log_gap

    # full loss
    loss = -RE + beta * KL

    if not average:
        return loss, RE, KL

    return torch.mean(loss), torch.mean(RE), torch.mean(KL)
def calculate_loss(self, x, beta=1., average=False):
    '''
    Compute ``-RE + beta * KL`` for the two-level latent model.

    :param x: input image(s)
    :param beta: weight of the KL term (a hyperparam for warmup)
    :param average: when True, ``loss``, ``RE`` and ``KL`` are each reduced
        with ``torch.mean``
    :return: tuple ``(loss, RE, KL)``
    :raises Exception: if ``self.args.input_type`` is not a supported type
    '''
    # pass through VAE
    outputs = self.forward(x)
    (x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar,
     z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar) = outputs

    # RE
    kind = self.args.input_type
    if kind == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif kind in ('gray', 'continuous'):
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
    log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
    log_p_z2 = self.log_p_z2(z2_q)
    log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
    KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

    loss = -RE + beta * KL

    if average:
        loss, RE, KL = torch.mean(loss), torch.mean(RE), torch.mean(KL)

    return loss, RE, KL
def calculate_loss(self, x, beta=1., average=False):
    '''
    Compute ``-RE + beta * KL`` for the two-level latent model.

    :param x: input image(s)
    :param beta: weight of the KL term
    :param average: when True, ``loss``, ``RE`` and ``KL`` are each reduced
        with ``torch.mean``
    :return: tuple ``(loss, RE, KL)``
    :raises Exception: if ``self.args.input_type`` is not a supported type
    '''
    # pass through VAE
    (x_mean, x_logvar,
     z1_q, z1_q_mean, z1_q_logvar,
     z2_q, z2_q_mean, z2_q_logvar,
     z1_p_mean, z1_p_logvar) = self.forward(x)

    # RE
    data_kind = self.args.input_type
    if data_kind == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif data_kind in ('gray', 'continuous'):
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
    log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
    log_p_z2 = self.log_p_z2(z2_q)
    log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
    prior_minus_posterior = log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2
    KL = -prior_minus_posterior

    # full loss
    loss = -RE + beta * KL

    if not average:
        return loss, RE, KL

    mean_loss = torch.mean(loss)
    mean_RE = torch.mean(RE)
    mean_KL = torch.mean(KL)
    return mean_loss, mean_RE, mean_KL
def log_q_x_vampprior(self, x):
    '''
    Evaluate the log of an equally weighted mixture over ``x`` whose
    components are decoded from the learned pseudo-inputs.

    Each of the ``C = self.args.number_components`` components is obtained
    by decoding ``means_z1(idle_input_z)`` / ``means_z2(idle_input_z)``
    through ``self.p_x``; the mixture is reduced with a max-shifted
    log-sum-exp over the component axis.

    :param x: input batch (original comments annotate it as MB x M)
    :return: log mixture density per batch row (annotated as MB)
    :raises Exception: if ``self.args.input_type`` is not 'binary', 'gray'
        or 'continuous'
    '''
    # Fail fast with the same error the loss functions use. Previously an
    # unsupported type left ``a`` unbound and surfaced as UnboundLocalError.
    input_type = self.args.input_type
    if input_type not in ('binary', 'gray', 'continuous'):
        raise Exception('Wrong input type!')

    C = self.args.number_components

    # latent codes of the mixture components
    Z1 = self.means_z1(self.idle_input_z).view(-1, self.args.z1_size)
    Z2 = self.means_z2(self.idle_input_z).view(-1, self.args.z2_size)

    # decoder parameters for every component (annotated as C x M)
    q_x_mean, q_x_logvar = self.p_x(z1=Z1, z2=Z2)

    # broadcast x against the components
    x_expand = x.unsqueeze(1)
    means = q_x_mean.unsqueeze(0)

    if input_type == 'binary':
        a = log_Bernoulli(x_expand, means, dim=2) - math.log(C)  # MB x C
    else:
        logvars = q_x_logvar.unsqueeze(0)
        a = -log_Logistic_256(x_expand, means, logvars, dim=2) - math.log(C)  # MB x C

    # log-sum-exp, shifted by the row-wise maximum for numerical stability
    a_max, _ = torch.max(a, 1)  # MB
    log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB

    return log_prior
def train_vae_vampprior_2level(epoch, args, train_loader, model, optimizer):
    '''
    Run one training epoch of the two-level model.

    :param epoch: current epoch number; drives the warm-up weight ``beta``
    :param args: run configuration (reads ``warmup``, ``cuda`` and
        ``dynamic_binarization``)
    :param train_loader: iterable yielding ``(data, target)`` batches
    :param model: model exposing ``forward`` and ``log_p_z2``
    :param optimizer: optimizer stepped once per batch
    :return: tuple ``(model, train_loss, train_re, train_kl)`` where the
        three scalars are averaged over the number of batches
    '''
    # running sums over batches
    train_loss = 0
    train_re = 0
    train_kl = 0

    # set model in training mode
    model.train()

    # warm-up: beta grows linearly with the epoch and is capped at 1
    if args.warmup == 0:
        beta = 1.
    else:
        beta = min(1., 1. * (epoch - 1) / args.warmup)
    print('beta: {}'.format(beta))

    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        # dynamic binarization
        x = torch.bernoulli(data) if args.dynamic_binarization else data

        # reset gradients
        optimizer.zero_grad()

        # forward pass
        (x_mean, x_logvar,
         z1_q, z1_q_mean, z1_q_logvar,
         z2_q, z2_q_mean, z2_q_logvar,
         z1_p_mean, z1_p_logvar) = model.forward(x)

        # RE
        RE = log_Bernoulli(x, x_mean)

        # KL, summed over the batch and scaled by the warm-up weight
        log_p_z2 = model.log_p_z2(z2_q)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = beta * (-torch.sum(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2))

        batch_size = data.size(0)
        loss = (-RE + KL) / batch_size

        # backward pass and parameter update
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        train_re += (-RE / batch_size).data[0]
        train_kl += (KL / batch_size).data[0]

    # per-batch values are already divided by the batch size, so only the
    # number of batches remains to be averaged out
    n_batches = len(train_loader)
    train_loss /= n_batches
    train_re /= n_batches
    train_kl /= n_batches

    return model, train_loss, train_re, train_kl
def calculate_likelihood(self, X, dir, mode='test', S=5000):
    '''
    Estimate the negative log-likelihood of every row of ``X`` from ``S``
    samples and return the mean over rows.

    For each row, ``RE - KL`` is evaluated for ``S`` forward passes (done
    in chunks of at most 100) and reduced as
    ``logsumexp(.) - log(number of samples)``.

    :param X: data tensor, indexed row by row
    :param dir: output location forwarded to ``plot_histogram``
    :param mode: label forwarded to ``plot_histogram``
    :param S: requested number of samples per data point
    :return: negative mean of the per-row log-likelihood estimates
    '''
    N_test = X.size(0)

    likelihood_test = []

    MB = 100

    if S <= MB:
        R = 1
    else:
        # Integer division: the original ``S / MB`` yields a float on
        # Python 3, which ``range`` rejects.
        # NOTE(review): a remainder of S modulo MB is not sampled.
        R = S // MB
        S = MB

    for j in range(N_test):
        if j % 100 == 0:
            print('{:.2f}%'.format(j / (1. * N_test) * 100))

        # Take x*
        x_single = X[j].unsqueeze(0)

        a = []
        for r in range(0, R):
            # Repeat it S times so one forward pass yields S samples
            x = x_single.expand(S, x_single.size(1))

            # pass through VAE
            (x_mean, x_logvar,
             z1_q, z1_q_mean, z1_q_logvar,
             z2_q, z2_q_mean, z2_q_logvar,
             z1_p_mean, z1_p_logvar) = self.forward(x)

            # RE
            RE = log_Bernoulli(x, x_mean, dim=1)

            # KL
            log_p_z2 = self.log_p_z2(z2_q)
            log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
            log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
            log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
            KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

            a_tmp = (RE - KL)

            a.append(a_tmp.cpu().data.numpy())

        # stack all chunks into a single column of R * S values
        a = np.asarray(a)
        a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
        likelihood_x = logsumexp(a)

        # Normalise by the number of values actually reduced. ``S`` was
        # overwritten with the chunk size above, so ``np.log(S)`` would be
        # off by ``log(R)`` whenever more than one chunk is used.
        likelihood_test.append(likelihood_x - np.log(a.shape[0]))

    likelihood_test = np.array(likelihood_test)

    plot_histogram(-likelihood_test, dir, mode)

    return -np.mean(likelihood_test)
def vae_re(x, x_decoded_mean):
    '''
    Reconstruction part of the objective, negated for use as a loss.

    :param x: target tensor
    :param x_decoded_mean: decoder mean for ``x``
    :return: ``-RE`` where RE is the per-row sum (axis 1) averaged over
        axis 0
    :raises ValueError: if ``config.data_type`` is neither 'binary' nor
        'gray'
    '''
    data_type = config.data_type

    if data_type == 'binary':
        elementwise = log_Bernoulli(x, x_decoded_mean)
        RE = K.mean(K.sum(elementwise, axis=1), axis=0)
    elif data_type == 'gray':
        # uses the module-level decoder log-variance tensor
        elementwise = log_Normal_diag(x, x_decoded_mean, x_decoded_log_var)
        offset = 0.5 * config.original_dim * np.log(2 * np.pi)
        RE = K.mean(K.sum(elementwise, axis=1), axis=0) + offset
    else:
        raise ValueError

    return -RE
def calculate_loss(self, x, beta=1., average=False):
    '''
    Compute ``-RE + beta * KL`` with an optional FI term, and report the
    FI and MI quantities alongside.

    :param x: input image(s)
    :param beta: weight of the KL term (a hyperparam for warmup)
    :param average: when True, every returned quantity is reduced to a mean
    :return: tuple ``(loss, RE, KL, FI, MI)``
    :raises Exception: if ``self.args.input_type`` is not a supported type
    '''
    # pass through VAE
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

    # RE
    input_type = self.args.input_type
    if input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif input_type in ('gray', 'continuous'):
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z = self.log_p_z(z_q)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    # FI: always evaluated; zeroed in place when the term is switched off
    fi_enabled = self.args.FI is True
    FI, gamma = self.FI(x)
    if not fi_enabled:
        FI *= 0

    # MI: always evaluated; multiplied by zero when switched off
    mi_enabled = self.args.MI is True
    MI = self.MI(x)
    if not mi_enabled:
        MI = MI * 0

    loss = -RE + beta * KL
    if fi_enabled:
        loss -= torch.mean(FI * gamma, dim=1)

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)
        # FI is reported as exp(mean(FI))
        FI = torch.mean(torch.exp(torch.mean(FI)))
        MI = torch.mean(MI)

    return loss, RE, KL, FI, MI
def calculate_loss(self, x, y=None, beta=1., average=False, head=None):
    '''
    Compute ``-RE + beta * KL`` and, when labels are supplied, add a
    cross-entropy term weighted by ``self.args.Lambda``.

    :param x: input image(s)
    :param y: optional labels; a 1-D ``y`` goes through ``F.nll_loss``,
        anything else is treated as per-class weights for ``log(y_hat)``
    :param beta: weight of the KL term (a hyperparam for warmup)
    :param average: whether to average loss or not
    :param head: accepted but not read in this function
    :return: ``(loss, RE, KL, x_mean)`` when ``y`` is None, otherwise
        ``(loss, RE, KL, CE, x_mean)``
    :raises Exception: if ``self.args.input_type`` is not a supported type
    '''
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar, y_hat = self.forward(x)

    # RE
    input_type = self.args.input_type
    if input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif input_type in ['gray', 'color']:
        # negative absolute error summed over dim 1
        abs_err = (x - x_mean).abs()
        RE = -abs_err.sum(dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z = self.log_p_z(z_q)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    # loss
    loss = -RE + beta * KL

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    # unsupervised call: nothing more to add
    if y is None:
        return loss, RE, KL, x_mean

    # CE
    log_y_hat = torch.log(y_hat)
    if len(y.shape) == 1:
        CE = F.nll_loss(log_y_hat, y)
    else:
        CE = -(y * log_y_hat).mean()

    # loss
    loss += self.args.Lambda * CE

    if average:
        CE = torch.mean(CE)

    return loss, RE, KL, CE, x_mean
def calculate_loss(self, x, beta=1., average=False, head=0, use_mixw_cor=False):
    '''
    Compute ``-RE + beta * KL`` for a given encoder head.

    :param x: input image(s)
    :param beta: weight of the KL term (a hyperparam for warmup)
    :param average: whether to average loss or not
    :param head: head index; clamped to the last entry of
        ``self.q_z_layers`` for the forward pass, passed unclamped to the
        prior
    :param use_mixw_cor: forwarded to ``self.log_p_z`` for non-GMM priors
    :return: tuple ``(loss, RE, KL, x_mean)``
    :raises Exception: if ``self.args.input_type`` is not a supported type
    '''
    # pass through VAE, never indexing past the last available head
    last_head = len(self.q_z_layers) - 1
    fw_head = min(head, last_head)
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x, head=fw_head)

    # RE
    input_type = self.args.input_type
    if input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif input_type in ['gray', 'color']:
        # negative absolute error summed over dim 1
        abs_err = (x - x_mean).abs()
        RE = -abs_err.sum(dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL: the GMM prior takes the posterior parameters, the others the head
    if self.args.prior == 'GMM':
        log_p_z = self.log_p_z(z_q, mu=z_q_mean, logvar=z_q_logvar)
    else:
        log_p_z = self.log_p_z(z_q, head=head, use_mixw_cor=use_mixw_cor)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    loss = -RE + beta * KL

    if not average:
        return loss, RE, KL, x_mean

    return torch.mean(loss), torch.mean(RE), torch.mean(KL), x_mean
def vae_loss(x, x_decoded_mean):
    '''
    Full objective ``-RE + KL`` built from module-level tensors.

    :param x: target tensor
    :param x_decoded_mean: decoder mean for ``x``
    :return: ``-RE + KL``; KL is multiplied by ``vae.beta`` when
        ``config.regularization`` is 'warmup'
    :raises ValueError: if ``config.data_type`` is neither 'binary' nor
        'gray'
    :raises Exception: if ``config.regularization`` is neither 'none' nor
        'warmup'
    '''
    # RE part
    data_type = config.data_type
    if data_type == 'binary':
        elementwise = log_Bernoulli(x, x_decoded_mean)
        RE = K.mean(K.sum(elementwise, axis=1), axis=0)
    elif data_type == 'gray':
        elementwise = log_Normal_diag(x, x_decoded_mean, x_decoded_log_var)
        offset = 0.5 * config.original_dim * np.log(2 * np.pi)
        RE = K.mean(K.sum(elementwise, axis=1), axis=0) + offset
    else:
        raise ValueError

    # KL part: q is scored at z_0, the prior at the output of the last flow
    log_q = log_Normal_diag(z_0, z_mean, z_log_var)
    last_flow_key = str(config.number_of_flows)
    log_p = log_Normal_standard(z[last_flow_key])
    kl = log_q - log_p

    regularization = config.regularization
    if regularization == 'none':
        KL = K.mean(K.sum(kl, axis=1), axis=0)
    elif regularization == 'warmup':
        KL = K.mean(K.sum(kl, axis=1), axis=0) * vae.beta
    else:
        raise Exception('wrong regularization name')

    return -RE + KL
def reconstruction_error(self, x, x_mean, x_logvar):
    '''
    Negative log-likelihood of ``x`` under the configured likelihood.

    :param x: target tensor
    :param x_mean: decoder mean
    :param x_logvar: decoder log-variance (only read for 'gaussian')
    :return: ``-log_Bernoulli(...)`` or ``-log_Gaus_diag(...)`` with
        ``dim=1``
    :raises ValueError: if ``self.likelihood`` is neither 'bernoulli' nor
        'gaussian'
    '''
    if self.likelihood == 'bernoulli':
        return -log_Bernoulli(x, x_mean, dim=1)
    if self.likelihood == 'gaussian':
        return -log_Gaus_diag(x, x_mean, x_logvar, dim=1)
    # Previously this fell through to ``return re`` with ``re`` unbound,
    # which surfaced as an opaque UnboundLocalError.
    raise ValueError('Unknown likelihood: {!r}'.format(self.likelihood))
def evaluate_vae(args, model, train_loader, data_loader, epoch, dir, mode):
    '''
    Evaluate the model on ``data_loader`` and, in test mode, produce
    plots, lower bounds and a log-likelihood estimate.

    :param args: run configuration (reads ``cuda`` and ``z1_size``)
    :param model: model under evaluation
    :param train_loader: loader whose dataset supplies the training data
        used for the train lower bound
    :param data_loader: loader being evaluated
    :param epoch: current epoch number
    :param dir: output location for plots
    :param mode: 'validation' or 'test'
    :return: ``(loss, re, kl)`` averaged over batches; in test mode also
        ``log_likelihood_test, log_likelihood_train, elbo_test, elbo_train``
    '''
    # running sums over batches
    evaluate_loss = 0
    evaluate_re = 0
    evaluate_kl = 0

    # set model to evaluation mode
    model.eval()

    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

        x = data

        # forward pass
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = model.forward(x)

        # RE
        RE = log_Bernoulli(x, x_mean)

        # KL against a standard normal prior, summed over the batch
        log_p_z = log_Normal_standard(z_q, dim=1)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = - torch.sum(log_p_z - log_q_z)

        batch_size = data.size(0)
        evaluate_loss += ((-RE + KL) / batch_size).data[0]
        evaluate_re += (-RE / batch_size).data[0]
        evaluate_kl += (KL / batch_size).data[0]

        # reconstructions are plotted for the second batch of a validation run
        if batch_idx == 1 and mode == 'validation':
            plot_reconstruction(args, x_mean, epoch, dir, size_x=3, size_y=3)
            if epoch == 1:
                # VISUALIZATION: plot real images
                plot_real(args, data[:26], dir + 'reconstruction/', size_x=3, size_y=3)

    is_test = mode == 'test'

    if is_test:
        # load all data
        test_data = Variable(data_loader.dataset.data_tensor)
        test_target = Variable(data_loader.dataset.target_tensor)
        full_data = Variable(train_loader.dataset.data_tensor)

        if args.cuda:
            test_data = test_data.cuda()
            test_target = test_target.cuda()
            full_data = full_data.cuda()

        full_data = torch.bernoulli(full_data)

        # VISUALIZATION: plot real images
        plot_real(args, test_data, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot reconstructions
        z_mean_recon, z_logvar_recon = model.q_z(test_data)
        z_recon = model.reparameterize(z_mean_recon, z_logvar_recon)
        samples, _ = model.p_x(z_recon)
        plot_reconstruction(args, samples, epoch, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot generations from 25 draws with zero mean
        zero_means = torch.from_numpy(np.zeros((25, args.z1_size))).float()
        z_sample_rand = Variable(torch.normal(zero_means, 1.))
        if args.cuda:
            z_sample_rand = z_sample_rand.cuda()
        samples_rand, _ = model.p_x(z_sample_rand)
        plot_generation(args, samples_rand, dir, size_x=5, size_y=5)

        if args.z1_size == 2:
            # VISUALIZATION: plot low-dimensional manifold
            plot_manifold(model, args, dir)
            # VISUALIZATION: plot scatter-plot
            plot_scatter(model, test_data, test_target, dir)

        # lower bound on the test data
        started = time.time()
        elbo_test = model.calculate_lower_bound(test_data)
        finished = time.time()
        print('Lower-bound time: {:.2f}'.format(finished - started))

        # lower bound on the training data
        started = time.time()
        elbo_train = model.calculate_lower_bound(full_data)
        finished = time.time()
        print('Lower-bound time: {:.2f}'.format(finished - started))

        # log-likelihood on the test data
        started = time.time()
        log_likelihood_test = model.calculate_likelihood(test_data, dir, mode='test')
        finished = time.time()
        print('Log_likelihood time: {:.2f}'.format(finished - started))

        # log-likelihood on the training data is not computed; a constant
        # placeholder is reported instead
        started = time.time()
        log_likelihood_train = 0.
        finished = time.time()
        print('Log_likelihood time: {:.2f}'.format(finished - started))

    # per-batch values are already divided by the batch size
    n_batches = len(data_loader)
    evaluate_loss /= n_batches
    evaluate_re /= n_batches
    evaluate_kl /= n_batches

    if not is_test:
        return evaluate_loss, evaluate_re, evaluate_kl

    return (evaluate_loss, evaluate_re, evaluate_kl, log_likelihood_test,
            log_likelihood_train, elbo_test, elbo_train)
def discriminator_loss_func(y, pred_y):
    '''
    Discriminator loss: twice the negated batch mean of the row-summed
    ``log_Bernoulli(y, pred_y)``.

    :param y: target tensor
    :param pred_y: predicted tensor
    :return: ``2. * -mean(sum(log_Bernoulli(y, pred_y), axis=1), axis=0)``
    '''
    elementwise = log_Bernoulli(y, pred_y)
    per_row = K.sum(elementwise, axis=1)
    batch_mean = K.mean(per_row, axis=0)
    return 2. * -batch_mean
def evaluate_vae_vampprior_2level(args, model, train_loader, data_loader, epoch, dir, mode):
    '''
    Evaluate the two-level model on ``data_loader`` and, in test mode,
    produce plots, lower bounds and a log-likelihood estimate.

    :param args: run configuration (reads ``cuda``, ``z1_size``, ``z2_size``)
    :param model: model under evaluation
    :param train_loader: loader whose dataset supplies the training data
        used for the train lower bound
    :param data_loader: loader being evaluated
    :param epoch: current epoch number
    :param dir: output location for plots
    :param mode: 'validation' or 'test'
    :return: ``(loss, re, kl)`` averaged over batches; in test mode also
        ``log_likelihood_test, log_likelihood_train, elbo_test, elbo_train``
    '''
    # running sums over batches
    evaluate_loss = 0
    evaluate_re = 0
    evaluate_kl = 0

    # set model to evaluation mode
    model.eval()

    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

        x = data

        # forward pass
        (x_mean, x_logvar,
         z1_q, z1_q_mean, z1_q_logvar,
         z2_q, z2_q_mean, z2_q_logvar,
         z1_p_mean, z1_p_logvar) = model.forward(x)

        # RE
        RE = log_Bernoulli(x, x_mean)

        # KL, summed over the batch
        log_p_z2 = model.log_p_z2(z2_q)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = - torch.sum(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        batch_size = data.size(0)
        evaluate_loss += ((-RE + KL) / batch_size).data[0]
        evaluate_re += (-RE / batch_size).data[0]
        evaluate_kl += (KL / batch_size).data[0]

        # reconstructions are plotted for the second batch of a validation run
        if batch_idx == 1 and mode == 'validation':
            plot_reconstruction(args, x_mean, epoch, dir, size_x=3, size_y=3)
            if epoch == 1:
                # VISUALIZATION: plot real images
                plot_real(args, data[:26], dir + 'reconstruction/', size_x=3, size_y=3)

    is_test = mode == 'test'

    if is_test:
        # load all data
        test_data = Variable(data_loader.dataset.data_tensor)
        test_target = Variable(data_loader.dataset.target_tensor)
        full_data = Variable(train_loader.dataset.data_tensor)

        if args.cuda:
            test_data = test_data.cuda()
            test_target = test_target.cuda()
            full_data = full_data.cuda()

        full_data = torch.bernoulli(full_data)

        # VISUALIZATION: plot real images
        plot_real(args, test_data, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot reconstructions (z2 first, then z1 given z2)
        z2_mean_recon, z2_logvar_recon = model.q_z2(test_data)
        z2_recon = model.reparameterize(z2_mean_recon, z2_logvar_recon)
        z1_mean_recon, z1_logvar_recon = model.q_z1(test_data, z2_recon)
        z1_recon = model.reparameterize(z1_mean_recon, z1_logvar_recon)
        samples, _ = model.p_x(z1_recon, z2_recon)
        plot_reconstruction(args, samples, epoch, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot generations from the first 25 pseudo-inputs
        means = model.means(model.idle_input)[:25]
        z2_sample_gen_mean, z2_sample_gen_logvar = model.q_z2(means)
        z2_sample_rand = model.reparameterize(z2_sample_gen_mean, z2_sample_gen_logvar)
        z1_mean_rand, z1_logvar_rand = model.p_z1(z2_sample_rand)
        z1_sample_rand = model.reparameterize(z1_mean_rand, z1_logvar_rand)
        samples_rand, _ = model.p_x(z1_sample_rand, z2_sample_rand)
        plot_generation(args, samples_rand, dir, size_x=5, size_y=5)

        if args.z1_size == 2 and args.z2_size == 2:
            # VISUALIZATION: plot low-dimensional manifold
            from utils.visual_evaluation import plot_manifold2
            plot_manifold2(model, args, dir)

            # VISUALIZATION: plot scatter-plot
            from utils.visual_evaluation import plot_scatter2
            z2_mean_recon, z2_logvar_recon = model.q_z2(test_data)
            z2_recon = model.reparameterize(z2_mean_recon, z2_logvar_recon)
            plot_scatter2(model, z2_recon, test_target, dir, name='scatter2D_z2.png')

            # z1 is re-sampled for each of the two z1 scatter plots
            for scatter_name in ('scatter2D_z1_1.png', 'scatter2D_z1_2.png'):
                z1_mean_recon, z1_logvar_recon = model.q_z1(test_data, z2_recon)
                z1_recon = model.reparameterize(z1_mean_recon, z1_logvar_recon)
                plot_scatter2(model, z1_recon, test_target, dir, name=scatter_name)

        # lower bound on the test data
        started = time.time()
        elbo_test = model.calculate_lower_bound(test_data)
        finished = time.time()
        print('Lower-bound time: {:.2f}'.format(finished - started))

        # lower bound on the training data
        started = time.time()
        elbo_train = model.calculate_lower_bound(full_data)
        finished = time.time()
        print('Lower-bound time: {:.2f}'.format(finished - started))

        # log-likelihood on the test data
        started = time.time()
        log_likelihood_test = model.calculate_likelihood(test_data, dir, mode='test')
        finished = time.time()
        print('Log_likelihood time: {:.2f}'.format(finished - started))

        # log-likelihood on the training data is not computed; a constant
        # placeholder is reported instead
        started = time.time()
        log_likelihood_train = 0.
        finished = time.time()
        print('Log_likelihood time: {:.2f}'.format(finished - started))

    # per-batch values are already divided by the batch size
    n_batches = len(data_loader)
    evaluate_loss /= n_batches
    evaluate_re /= n_batches
    evaluate_kl /= n_batches

    if not is_test:
        return evaluate_loss, evaluate_re, evaluate_kl

    return (evaluate_loss, evaluate_re, evaluate_kl, log_likelihood_test,
            log_likelihood_train, elbo_test, elbo_train)
def calculate_loss(self, x, beta=1., average=False):
    '''
    Compute the MIM loss on ``x`` and, when ``self.p_samp`` is set and
    ``beta >= 1.0``, add the same loss evaluated on samples generated by
    the model.

    :param x: input image(s)
    :param beta: weight applied to ``log_p_z2``, to the q-side terms and to
        the sampled-loss contribution
    :param average: when True, ``loss``, ``RE`` and ``KL`` of the data pass
        are reduced with ``torch.mean``
    :return: tuple ``(loss, RE, KL)``; ``RE`` and ``KL`` always come from
        the data pass
    :raises Exception: if ``self.args.input_type`` is not a supported type
    '''
    def data_log_likelihood(inputs, mean, logvar):
        # log p(x | z) under the configured observation model
        if self.args.input_type == 'binary':
            return log_Bernoulli(inputs, mean, dim=1)
        if self.args.input_type == 'gray' or self.args.input_type == 'continuous':
            return -log_Logistic_256(inputs, mean, logvar, dim=1)
        raise Exception('Wrong input type!')

    # pass through VAE
    (x_mean, x_logvar,
     z1_q, z1_q_mean, z1_q_logvar,
     z2_q, z2_q_mean, z2_q_logvar,
     z1_p_mean, z1_p_logvar) = self.forward(x)

    # p(x|z)p(z)
    log_p_x_given_z = data_log_likelihood(x, x_mean, x_logvar)
    log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
    log_p_z2 = self.log_p_z2(z2_q)
    log_p_z = log_p_z1 + beta * log_p_z2

    # q(z|x)q(x)
    log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
    log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
    log_q_z_given_x = log_q_z1 + log_q_z2

    if self.args.q_x_prior == "marginal":
        # q(x) is marginal of p(x, z)
        log_q_x = log_p_x_given_z + log_p_z - log_q_z_given_x
    elif self.args.q_x_prior == "vampprior":
        # q(x) is vampprior of p(x|u)
        log_q_x = self.log_q_x_vampprior(x)

    RE = log_p_x_given_z
    KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

    # MIM loss
    loss = -0.5 * (log_p_x_given_z + log_p_z + beta * (log_q_z_given_x + log_q_x))

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    # symmetric sampling: only once warm-up is complete
    if self.p_samp and (beta >= 1.0):
        # draw fresh latents and x from the model (no latents supplied)
        z1_q, z2_q, x = self.generate_x(
            N=x.shape[0],
            return_z=True,
            z1=None,
            z2=None,
        )

        # discrete samples should have no gradients
        if self.args.input_type == 'binary':
            x = x.detach()

        x_shape = (-1, ) + tuple(self.args.input_size)
        x = x.view(x_shape)

        # parameters of q(z2 | x), q(z1 | x, z2) and p(z1 | z2)
        z2_q_mean, z2_q_logvar = self.q_z2(x)
        z1_q_mean, z1_q_logvar = self.q_z1(x, z2_q)
        z1_p_mean, z1_p_logvar = self.p_z1(z2_q)

        # x_mean = p(x|z1,z2)
        x_mean, x_logvar = self.p_x(z1_q, z2_q)

        x = x.view((x.shape[0], -1))

        # p(x|z)p(z)
        log_p_x_given_z = data_log_likelihood(x, x_mean, x_logvar)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_p_z2 = self.log_p_z2(z2_q)
        log_p_z = log_p_z1 + beta * log_p_z2

        # q(z|x)q(x)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        log_q_z_given_x = log_q_z1 + log_q_z2

        if self.args.q_x_prior == "marginal":
            # q(x) is marginal of p(x, z)
            log_q_x = log_p_x_given_z
        elif self.args.q_x_prior == "vampprior":
            # q(x) is vampprior of p(x|u)
            log_q_x = self.log_q_x_vampprior(x)

        loss_p = -0.5 * (log_p_x_given_z + log_p_z + beta * (log_q_z_given_x + log_q_x))

        # REINFORCE term for discrete samples: the value of loss_p is kept,
        # an extra gradient path through log p(x|z) is added
        if self.args.input_type == 'binary':
            reinforce = loss_p.detach() * log_p_x_given_z
            baseline = (loss_p * log_p_x_given_z).detach()
            loss_p = loss_p + reinforce - baseline

        # MIM loss
        loss += beta * loss_p.mean()

    return loss, RE, KL
def calculate_objective(self, X, Y):
    '''
    Negated ``log_Bernoulli`` of the labels under the model's prediction.

    :param X: model input, forwarded to ``self.forward``
    :param Y: labels; cast to float before scoring
    :return: tuple ``(-log_Bernoulli(Y, Y_prob), A)`` where ``Y_prob`` and
        ``A`` are the first and third outputs of ``self.forward``
    '''
    targets = Y.float()
    Y_prob, _, A = self.forward(X)
    neg_log_likelihood = -log_Bernoulli(targets, Y_prob)

    return neg_log_likelihood, A