def calculate_loss(self, x, beta=1., average=False):
    # pass through VAE
    x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(x)

    # RE
    if self.args.input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
    log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
    log_p_z2 = self.log_p_z2(z2_q)
    log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
    KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

    # full loss
    loss = -RE + beta * KL

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    return loss, RE, KL
def calculate_loss(self, x, beta=1., average=False):
    '''
    :param x: input image(s)
    :param beta: a hyperparam for warmup
    :param average: whether to average loss or not
    :return: value of a loss function
    '''
    # pass through VAE
    x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(x)

    # RE
    if self.args.input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
    log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
    log_p_z2 = self.log_p_z2(z2_q)
    log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
    KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

    loss = -RE + beta * KL

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    return loss, RE, KL
def calculate_loss(self, x, beta=1., average=False):
    # pass through VAE
    x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(x)

    # RE
    if self.args.input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
        print(x.shape, x_mean.shape, x_logvar.shape)
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
    log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
    log_p_z2 = self.log_p_z2(z2_q)
    log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
    KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

    # full loss
    loss = -RE + beta * KL

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    return loss, RE, KL
def calculate_lower_bound(self, X_full):
    # CALCULATE LOWER BOUND:
    lower_bound = 0.
    RE_all = 0.
    KL_all = 0.

    MB = 100

    # integer division so that range() receives an int in Python 3 as well
    for i in range(X_full.size(0) // MB):
        x = X_full[i * MB:(i + 1) * MB].view(-1, np.prod(self.args.input_size))

        # pass through VAE
        x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(x)

        # RE
        RE = log_Bernoulli(x, x_mean)

        # KL
        log_p_z2 = self.log_p_z2(z2_q)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = -torch.sum(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        RE_all += RE.cpu().data[0]
        KL_all += KL.cpu().data[0]

        # CALCULATE LOWER-BOUND: RE + KL - ln(N)
        lower_bound += (-RE + KL).cpu().data[0]

    lower_bound = lower_bound / X_full.size(0)

    return lower_bound
def train_vae_vampprior_2level(epoch, args, train_loader, model, optimizer):
    # set loss to 0
    train_loss = 0
    train_re = 0
    train_kl = 0
    # set model in training mode
    model.train()

    # start training
    if args.warmup == 0:
        beta = 1.
    else:
        beta = 1. * (epoch - 1) / args.warmup
        if beta > 1.:
            beta = 1.
    print('beta: {}'.format(beta))

    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        # dynamic binarization
        if args.dynamic_binarization:
            x = torch.bernoulli(data)
        else:
            x = data

        # reset gradients
        optimizer.zero_grad()
        # forward pass
        x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = model.forward(x)

        # loss function
        # RE
        RE = log_Bernoulli(x, x_mean)
        # KL
        log_p_z2 = model.log_p_z2(z2_q)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = beta * (-torch.sum(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2))

        loss = (-RE + KL) / data.size(0)

        # backward pass
        loss.backward()
        # optimization
        optimizer.step()

        train_loss += loss.data[0]
        train_re += (-RE / data.size(0)).data[0]
        train_kl += (KL / data.size(0)).data[0]

    # calculate final loss
    train_loss /= len(train_loader)  # loss function already averages over batch size
    train_re /= len(train_loader)  # re already averages over batch size
    train_kl /= len(train_loader)  # kl already averages over batch size

    return model, train_loss, train_re, train_kl
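# A minimal, self-contained sketch of the linear KL warm-up schedule used in the
# training loop above: beta grows linearly from 0 to 1 over `warmup` epochs and is
# then clipped at 1. The helper name `warmup_beta` is illustrative and not part of
# the original code.
def warmup_beta(epoch, warmup):
    """Return the KL weight for a given (1-indexed) epoch."""
    if warmup == 0:
        return 1.
    return min(1., (epoch - 1) / float(warmup))

# Example: with warmup=100, epoch 1 gives beta=0.0, epoch 51 gives beta=0.5,
# and every epoch >= 101 gives beta=1.0.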
def calculate_likelihood(self, X, dir, mode='test', S=5000):
    # set auxiliary variables for number of training and test sets
    N_test = X.size(0)

    # init list
    likelihood_test = []

    MB = 100

    if S <= MB:
        R = 1
    else:
        # integer division so that range() receives an int in Python 3 as well
        R = S // MB
        S = MB

    for j in range(N_test):
        if j % 100 == 0:
            print('{:.2f}%'.format(j / (1. * N_test) * 100))

        # Take x*
        x_single = X[j].unsqueeze(0)

        a = []
        for r in range(0, R):
            # Repeat it for all training points
            x = x_single.expand(S, x_single.size(1))

            # pass through VAE
            x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(x)

            # RE
            RE = log_Bernoulli(x, x_mean, dim=1)

            # KL
            log_p_z2 = self.log_p_z2(z2_q)
            log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
            log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
            log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
            KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

            a_tmp = (RE - KL)
            a.append(a_tmp.cpu().data.numpy())

        # calculate max
        a = np.asarray(a)
        a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
        likelihood_x = logsumexp(a)
        likelihood_test.append(likelihood_x - np.log(S))

    likelihood_test = np.array(likelihood_test)

    plot_histogram(-likelihood_test, dir, mode)

    return -np.mean(likelihood_test)
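# A minimal sketch of the estimate computed by calculate_likelihood above: log p(x) is
# approximated by a logsumexp over importance weights log p(x, z_s) - log q(z_s | x)
# (each of which equals RE - KL for one posterior sample), minus the log of the total
# number of samples. The helper below is illustrative only; `log_weights` is assumed to
# be a 1-D array holding the per-sample values of (RE - KL).
import numpy as np
from scipy.special import logsumexp

def importance_sampling_log_likelihood(log_weights):
    """Numerically stable Monte Carlo estimate of log p(x) from importance weights."""
    log_weights = np.asarray(log_weights).ravel()
    return logsumexp(log_weights) - np.log(log_weights.shape[0])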
def log_p_z2(self, z2):
    # z2 - MB x M
    # X  - N x D
    MB = z2.size(0)
    C = self.args.number_components
    M = z2.size(1)

    # calculate pseudo-inputs
    X = self.means(self.idle_input)

    # calculate params for given data
    z_p_mean, z_p_logvar = self.q_z2(X)  # C x M

    # expand z
    z_expand = z2.unsqueeze(1).expand(MB, C, M)
    means = z_p_mean.unsqueeze(0).expand(MB, C, M)
    logvars = z_p_logvar.unsqueeze(0).expand(MB, C, M)

    a = log_Normal_diag(z_expand, means, logvars, dim=2).squeeze(2) - math.log(C)  # MB x C
    a_max, _ = torch.max(a, 1)  # MB x 1
    # calculate log-sum-exp
    log_p_z = a_max + torch.log(torch.sum(torch.exp(a - a_max.expand(MB, C)), 1))  # MB x 1

    return log_p_z
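# An illustrative alternative to the manual max-shift above: torch.logsumexp performs
# the same stabilization internally. This is a sketch under the assumption of a PyTorch
# version that provides torch.logsumexp (>= 0.4.1); the diagonal-Gaussian log-density
# is passed in as an argument rather than assumed to exist as a global log_Normal_diag.
import math
import torch

def log_mixture_prior(z, comp_means, comp_logvars, log_normal_diag):
    """log p(z) for an equally weighted mixture of C diagonal Gaussians.

    z:            MB x M latent codes
    comp_means:   C x M component means
    comp_logvars: C x M component log-variances
    """
    C = comp_means.size(0)
    # MB x C matrix of per-component log-densities, minus log C for uniform weights
    a = log_normal_diag(z.unsqueeze(1), comp_means.unsqueeze(0),
                        comp_logvars.unsqueeze(0), dim=2) - math.log(C)
    return torch.logsumexp(a, dim=1)  # MB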
def calculate_loss(self, x, beta=1., average=False):
    '''
    :param x: input image(s)
    :param beta: a hyperparam for warmup
    :param average: whether to average loss or not
    :return: value of a loss function
    '''
    # pass through VAE
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

    # RE
    if self.args.input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z = self.log_p_z(z_q)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    loss = -RE + beta * KL

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    return loss, RE, KL
def calculate_loss(self, x, beta=1., average=False):
    '''
    :param x: input image(s)
    :param beta: a hyperparam for warmup
    :param average: whether to average loss or not
    :return: value of a loss function
    '''
    # pass through VAE
    x = x.view(-1, np.prod(self.args.input_size))
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

    # RE
    if self.args.input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z = self.log_p_z(z_q)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    if self.isnan(z_q_mean.data[0][0]):
        print("mean:")
        print(z_q_mean)
    if self.isnan(z_q_logvar.data[0][0]):
        print("var:")
        print(z_q_logvar)

    loss = -RE + beta * KL

    # FI
    if self.args.FI is True:
        FI, gamma = self.FI(x)
        # loss -= torch.mean(FI * gamma, dim=1)
    else:
        FI, gamma = self.FI(x)
        FI *= 0.
    # FI = (torch.mean((torch.log(2 * torch.pow(torch.exp(z_q_logvar), 2) + 1) - 2 * z_q_logvar)) - self.args.M).abs()
    # FI = (torch.mean((1 / torch.exp(z_q_logvar) + 1 / (2 * torch.pow(torch.exp(z_q_logvar), 2)))) - self.args.M).abs()

    # MI
    if self.args.MI is True:
        MI = self.MI(x)
        # loss += self.args.ksi * (MI - self.args.M).abs()
    else:
        MI = self.MI(x) * 0.

    if self.args.adv is True:
        loss += self.args.ksi * (torch.exp(MI) - FI).abs()

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)
        FI = torch.mean(FI)
        MI = torch.mean(torch.exp(MI))

    return loss, RE, KL, FI, MI
def log_p_z2(self, z2):
    # vamp prior
    # z2 - MB x M
    C = self.args.number_components

    # calculate pseudo-inputs
    X = self.means(self.idle_input)

    # calculate params for given data
    z2_p_mean, z2_p_logvar = self.q_z2(X)  # C x M

    # expand z
    z_expand = z2.unsqueeze(1)
    means = z2_p_mean.unsqueeze(0)
    logvars = z2_p_logvar.unsqueeze(0)

    a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C
    a_max, _ = torch.max(a, 1)  # MB
    # calculate log-sum-exp
    log_prior = (a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1)))  # MB

    return log_prior
def calculate_loss(self, x, beta=1., average=False):
    '''
    :param x: input image(s)
    :param beta: a hyperparam for warmup
    :param average: whether to average loss or not
    :return: value of a loss function
    '''
    # pass through VAE
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

    # RE
    RE = log_Softmax(x, x_mean, dim=1)  #! Actually not Reconstruction Error but Log-Likelihood

    # KL
    log_p_z = self.log_p_z(z_q)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    loss = -RE + beta * KL

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    return loss, RE, KL
def calculate_loss(self, x, beta=1., average=False):
    '''
    :param x: input image(s)
    :param beta: a hyperparam for warmup
    :param average: whether to average loss or not
    :return: value of a loss function
    '''
    # pass through VAE
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

    # RE
    if self.args.input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type == 'multinomial':
        RE = log_Softmax(x, x_mean, dim=1)  #! Actually not Reconstruction Error but Log-Likelihood
    elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z = self.log_p_z(z_q)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    loss = -RE + beta * KL

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    return loss, RE, KL
def log_p_z2(self, z2):
    if self.args.prior == 'standard':
        log_prior = log_Normal_standard(z2, dim=1)

    elif self.args.prior == 'vampprior':
        # z - MB x M
        C = self.args.number_components

        # calculate pseudo-inputs
        X = self.means(self.idle_input).view(-1, self.args.input_size[0],
                                             self.args.input_size[1],
                                             self.args.input_size[2])

        # calculate params for given data
        z2_p_mean, z2_p_logvar = self.q_z2(X)  # C x M

        # expand z
        z_expand = z2.unsqueeze(1)
        means = z2_p_mean.unsqueeze(0)
        logvars = z2_p_logvar.unsqueeze(0)

        a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C
        a_max, _ = torch.max(a, 1)  # MB
        # calculate log-sum-exp
        log_prior = (a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1)))  # MB
    else:
        raise Exception('Wrong name of the prior!')

    return log_prior
def log_p_z(self, z):
    if self.args.prior == 'standard':
        log_prior = log_Normal_standard(z, dim=1)

    elif self.args.prior == 'vampprior':
        # z - MB x M
        C = self.args.number_components

        # calculate pseudo-inputs
        X = self.means(self.idle_input)
        X = X.view(X.shape[0], 1, self.args.input_size[1], self.args.input_size[2])

        # calculate params for given data
        z_p_mean, z_p_logvar = self.q_z(X)  # C x M

        # expand z
        z_expand = z.unsqueeze(1)
        means = z_p_mean.unsqueeze(0)
        logvars = z_p_logvar.unsqueeze(0)

        a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C
        a_max, _ = torch.max(a, 1)  # MB x 1
        # calculate log-sum-exp
        log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1
    else:
        raise Exception('Wrong name of the prior!')

    return log_prior
def log_p_z2(self, z2):
    if self.args.prior == 'standard':
        log_prior = log_Normal_standard(z2, dim=1)

    elif self.args.prior == 'vampprior':
        # z2 - MB x M
        C = self.args.number_components

        # calculate pseudo-inputs
        X = self.means(self.idle_input)

        # calculate params for given data
        z2_p_mean, z2_p_logvar = self.q_z2(X)  # C x M

        # expand z
        z_expand = z2.unsqueeze(1)
        means = z2_p_mean.unsqueeze(0)
        logvars = z2_p_logvar.unsqueeze(0)

        a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C
        a_max, _ = torch.max(a, 1)  # MB
        # calculate log-sum-exp
        log_prior = (a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1)))  # MB
    else:
        raise Exception('Wrong name of the prior!')

    return log_prior
def calculate_lower_bound(self, X_full):
    # CALCULATE LOWER BOUND:
    lower_bound = 0.
    RE_all = 0.
    KL_all = 0.

    MB = 500

    for i in range(int(X_full.size(0) / MB)):
        x = X_full[i * MB:(i + 1) * MB].view(-1, np.prod(self.input_size))

        # pass through VAE
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)
        x_mean = x_mean.double()

        # RE
        RE = log_Bernoulli(x, x_mean)

        # KL
        log_p_z = log_Normal_standard(z_q, dim=1)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -torch.sum(log_p_z - log_q_z)

        RE_all += RE.cpu().item()
        KL_all += KL.cpu().item()

        # CALCULATE LOWER-BOUND: RE + KL - ln(N)
        lower_bound += (-RE + KL).cpu().item()

    lower_bound = lower_bound / X_full.size(0)

    return lower_bound
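# A hedged usage sketch for the evaluation helper above: computing the per-example
# negative ELBO on a held-out tensor. `model` and `X_val` are assumptions (a trained
# model exposing calculate_lower_bound, and a tensor of flattened validation images
# on the same device as the model).
model.eval()
with torch.no_grad():
    val_neg_elbo = model.calculate_lower_bound(X_val)
print('validation -ELBO per example: {:.2f}'.format(val_neg_elbo))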
def vae_re(x, x_decoded_mean):
    # RE part
    if config.data_type == 'binary':
        re_loss = log_Bernoulli(x, x_decoded_mean)
        RE = K.mean(K.sum(re_loss, axis=1), axis=0)
    elif config.data_type == 'gray':
        re_loss = log_Normal_diag(x, x_decoded_mean, x_decoded_log_var)
        RE = K.mean(K.sum(re_loss, axis=1), axis=0) + 0.5 * config.original_dim * np.log(2 * np.pi)
    else:
        raise ValueError

    return -RE
def calculate_loss(self, x, y=None, beta=1., average=False, head=None):
    '''
    :param x: input image(s)
    :param beta: a hyperparam for warmup
    :param average: whether to average loss or not
    :return: value of a loss function
    '''
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar, y_hat = self.forward(x)

    # RE
    if self.args.input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type in ['gray', 'color']:
        # RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        RE = -(x - x_mean).abs().sum(dim=1)
    # elif self.args.input_type == 'color':
    #     RE = -log_Normal_diag(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z = self.log_p_z(z_q)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    # loss
    loss = -RE + beta * KL  # + self.args.Lambda * CE

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    if y is None:
        return loss, RE, KL, x_mean

    # CE
    if len(y.shape) == 1:
        CE = F.nll_loss(torch.log(y_hat), y)
    else:
        CE = -(y * torch.log(y_hat)).mean()

    # loss
    loss += self.args.Lambda * CE

    if average:
        CE = torch.mean(CE)

    return loss, RE, KL, CE, x_mean
def vae_loss(x, x_decoded_mean):
    # RE part
    if config.data_type == 'binary':
        re_loss = log_Bernoulli(x, x_decoded_mean)
        RE = K.mean(K.sum(re_loss, axis=1), axis=0)
    elif config.data_type == 'gray':
        re_loss = log_Normal_diag(x, x_decoded_mean, x_decoded_log_var)
        RE = K.mean(K.sum(re_loss, axis=1), axis=0) + 0.5 * config.original_dim * np.log(2 * np.pi)
    else:
        raise ValueError

    # KL part
    log_q = log_Normal_diag(z_0, z_mean, z_log_var)
    log_p = log_Normal_standard(z[str(config.number_of_flows)])
    kl = log_q - log_p

    if config.regularization == 'none':
        KL = K.mean(K.sum(kl, axis=1), axis=0)
    elif config.regularization == 'warmup':
        KL = K.mean(K.sum(kl, axis=1), axis=0) * vae.beta
    else:
        raise Exception('wrong regularization name')

    return -RE + KL
def forward(self, HR_feat, LR, down_ref):
    z_q_mu, z_q_logvar = self.encode(HR_feat)

    # reparameterize
    z_q = self.reparameterize(z_q_mu, z_q_logvar, flag=0)

    # prior
    log_p_z = self.log_p_z(z_q, self.idle_input, self.number_component, self.prior)

    # KL
    log_q_z = log_Normal_diag(z_q, z_q_mu, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)
    KL = torch.sum(KL)

    Denoise_LR, SR = self.decode(LR, down_ref, z_q)

    return Denoise_LR, SR, KL
def calculate_loss(self, x, beta=1., average=False, head=0, use_mixw_cor=False):
    '''
    :param x: input image(s)
    :param beta: a hyperparam for warmup
    :param average: whether to average loss or not
    :return: value of a loss function
    '''
    # pass through VAE
    fw_head = min(head, len(self.q_z_layers) - 1)
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x, head=fw_head)

    # RE
    if self.args.input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type in ['gray', 'color']:
        # RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        RE = -(x - x_mean).abs().sum(dim=1)
    # elif self.args.input_type == 'color':
    #     RE = -log_Normal_diag(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    if self.args.prior == 'GMM':
        log_p_z = self.log_p_z(z_q, mu=z_q_mean, logvar=z_q_logvar)
    else:
        log_p_z = self.log_p_z(z_q, head=head, use_mixw_cor=use_mixw_cor)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    loss = -RE + beta * KL

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    return loss, RE, KL, x_mean
def MI(self, x):
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)
    z_q_mean, z_q_logvar, _ = self.q_z(x_mean)
    z_q = self.reparameterize(z_q_mean, z_q_logvar)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, average=True, dim=1)

    epsilon = 1e-10
    mi_loss = (torch.mean(torch.log(torch.add(torch.exp(log_q_z), epsilon))) - self.args.M).abs()

    if self.isnan(mi_loss.data[0]):
        print(log_q_z)

    return mi_loss
def log_p_z(self, z, mu=None, logvar=None, head=0, use_mixw_cor=False):
    if self.args.prior == 'standard':
        log_prior = log_Normal_standard(z, dim=1)

    elif self.args.prior == 'vampprior':
        # z - MB x M
        C = self.args.number_components

        # calculate pseudo-inputs
        X = self.means[head](self.idle_input)
        if self.args.dataset_name == 'celeba':
            X = X.reshape(X.size(0), 3, 64, 64)

        # calculate params for given data
        z_p_mean, z_p_logvar = self.q_z(X, head=head)  # C x M

        # expand z
        z_expand = z.unsqueeze(1)
        means = z_p_mean.unsqueeze(0)
        logvars = z_p_logvar.unsqueeze(0)

        a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C
        a_max, _ = torch.max(a, 1)  # MB x 1
        # calculate log-sum-exp
        log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1

    elif self.args.prior == 'vampprior_short':
        C = self.args.number_components

        # calculate pseudo-inputs
        if self.args.separate_means:
            K = self.means[head].linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda:
                eye = eye.cuda()
            X = self.means[head](eye)
        else:
            K = self.means.linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda:
                eye = eye.cuda()
            X = self.means(eye)

        z_p_mean = self.q_z_mean(X)
        z_p_logvar = self.q_z_logvar(X)

        # expand z
        z_expand = z.unsqueeze(1)
        means = z_p_mean.unsqueeze(0)
        logvars = z_p_logvar.unsqueeze(0)

        # haven't yet implemented dealing with mixing weights in the separated means case
        if self.args.use_vampmixingw:
            pis = self.mixingw(eye).t()
            if use_mixw_cor:
                if self.args.cuda:
                    pis = pis * torch.from_numpy(self.mixingw_c).cuda().unsqueeze(0)
                else:
                    pis = pis * torch.from_numpy(self.mixingw_c).unsqueeze(0)
            eps = 1e-30
            a = log_Normal_diag(z_expand, means, logvars, dim=2) + torch.log(pis)  # MB x C
        else:
            a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C

        a_max, _ = torch.max(a, 1)  # MB x 1
        # calculate log-sum-exp
        log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1

    elif self.args.prior == 'GMM':
        if self.GMM.covariance_type == 'full':
            # z_np = z.data.cpu().numpy()
            # lls = self.GMM.score_samples(z_np)
            # log_prior = torch.from_numpy(lls).float().cuda() + math.log(2 * math.pi) * (0.5 * z.size(1))
            log_prior = log_mog_full(z, self.mus, self.sigs, self.pis, self.icovs_ten, self.det_terms)
        else:
            log_prior = log_mog_diag(z, self.mus, self.sigs, self.pis)

    else:
        raise Exception('invalid prior!')

    return log_prior
def log_p_z(self, z, mu=None, logvar=None):
    if self.args.prior == 'standard':
        log_prior = log_Normal_standard(z, dim=1)

    elif self.args.prior == 'vampprior':
        # z - MB x M
        C = self.args.number_components

        # calculate pseudo-inputs
        X = self.means(self.idle_input)
        if self.args.dataset_name == 'celeba':
            X = X.reshape(X.size(0), 3, 64, 64)

        # calculate params for given data
        z_p_mean, z_p_logvar = self.q_z(X)  # C x M

        # expand z
        z_expand = z.unsqueeze(1)
        means = z_p_mean.unsqueeze(0)
        logvars = z_p_logvar.unsqueeze(0)

        a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C
        a_max, _ = torch.max(a, 1)  # MB x 1
        # calculate log-sum-exp
        log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1

    elif self.args.prior == 'vampprior_short':
        C = self.args.number_components

        K = self.means.linear.weight.size(1)
        eye = torch.eye(K, K)
        if self.args.cuda:
            eye = eye.cuda()
        X = self.means(eye)

        z_p_mean = self.q_z_mean(X)
        z_p_logvar = self.q_z_logvar(X)

        # expand z
        z_expand = z.unsqueeze(1)
        means = z_p_mean.unsqueeze(0)
        logvars = z_p_logvar.unsqueeze(0)

        # haven't yet implemented dealing with mixing weights in the separated means case
        if self.args.use_vampmixingw:
            pis = self.mixingw(eye).t()
            eps = 1e-30
            a = log_Normal_diag(z_expand, means, logvars, dim=2) + torch.log(pis)  # MB x C
        else:
            a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C

        a_max, _ = torch.max(a, 1)  # MB x 1
        # calculate log-sum-exp
        log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1

    else:
        raise Exception('invalid prior!')

    return log_prior
def evaluate_vae(args, model, train_loader, data_loader, epoch, dir, mode):
    # set loss to 0
    evaluate_loss = 0
    evaluate_re = 0
    evaluate_kl = 0
    # set model to evaluation mode
    model.eval()

    # evaluate
    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

        x = data

        # forward pass
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = model.forward(x)

        # loss function
        # RE
        RE = log_Bernoulli(x, x_mean)
        # KL
        log_p_z = log_Normal_standard(z_q, dim=1)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -torch.sum(log_p_z - log_q_z)

        evaluate_loss += ((-RE + KL) / data.size(0)).data[0]
        evaluate_re += (-RE / data.size(0)).data[0]
        evaluate_kl += (KL / data.size(0)).data[0]

        # plot N digits
        if batch_idx == 1 and mode == 'validation':
            plot_reconstruction(args, x_mean, epoch, dir, size_x=3, size_y=3)
            if epoch == 1:
                # VISUALIZATION: plot real images
                plot_real(args, data[:26], dir + 'reconstruction/', size_x=3, size_y=3)

    if mode == 'test':
        # load all data
        test_data = Variable(data_loader.dataset.data_tensor)
        test_target = Variable(data_loader.dataset.target_tensor)
        full_data = Variable(train_loader.dataset.data_tensor)

        if args.cuda:
            test_data, test_target, full_data = test_data.cuda(), test_target.cuda(), full_data.cuda()

        full_data = torch.bernoulli(full_data)

        # VISUALIZATION: plot real images
        plot_real(args, test_data, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot reconstructions
        z_mean_recon, z_logvar_recon = model.q_z(test_data)
        z_recon = model.reparameterize(z_mean_recon, z_logvar_recon)
        samples, _ = model.p_x(z_recon)
        plot_reconstruction(args, samples, epoch, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot generations
        z_sample_rand = Variable(torch.normal(torch.from_numpy(np.zeros((25, args.z1_size))).float(), 1.))
        if args.cuda:
            z_sample_rand = z_sample_rand.cuda()
        samples_rand, _ = model.p_x(z_sample_rand)
        plot_generation(args, samples_rand, dir, size_x=5, size_y=5)

        if args.z1_size == 2:
            # VISUALIZATION: plot low-dimensional manifold
            plot_manifold(model, args, dir)
            # VISUALIZATION: plot scatter-plot
            plot_scatter(model, test_data, test_target, dir)

        # CALCULATE lower-bound on test data
        t_ll_s = time.time()
        elbo_test = model.calculate_lower_bound(test_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE lower-bound on training data
        t_ll_s = time.time()
        elbo_train = model.calculate_lower_bound(full_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE log-likelihood on test data
        t_ll_s = time.time()
        log_likelihood_test = model.calculate_likelihood(test_data, dir, mode='test')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE log-likelihood on training data (skipped)
        t_ll_s = time.time()
        log_likelihood_train = 0.  # model.calculate_likelihood(full_data, dir, mode='train')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

    # calculate final loss
    evaluate_loss /= len(data_loader)  # loss function already averages over batch size
    evaluate_re /= len(data_loader)  # re already averages over batch size
    evaluate_kl /= len(data_loader)  # kl already averages over batch size

    if mode == 'test':
        return evaluate_loss, evaluate_re, evaluate_kl, log_likelihood_test, log_likelihood_train, elbo_test, elbo_train
    else:
        return evaluate_loss, evaluate_re, evaluate_kl
def evaluate_vae_vampprior_2level(args, model, train_loader, data_loader, epoch, dir, mode):
    # set loss to 0
    evaluate_loss = 0
    evaluate_re = 0
    evaluate_kl = 0
    # set model to evaluation mode
    model.eval()

    # evaluate
    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

        x = data

        # forward pass
        x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = model.forward(x)

        # loss function
        # RE
        RE = log_Bernoulli(x, x_mean)
        # KL
        log_p_z2 = model.log_p_z2(z2_q)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = -torch.sum(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        evaluate_loss += ((-RE + KL) / data.size(0)).data[0]
        evaluate_re += (-RE / data.size(0)).data[0]
        evaluate_kl += (KL / data.size(0)).data[0]

        # plot N digits
        if batch_idx == 1 and mode == 'validation':
            plot_reconstruction(args, x_mean, epoch, dir, size_x=3, size_y=3)
            if epoch == 1:
                # VISUALIZATION: plot real images
                plot_real(args, data[:26], dir + 'reconstruction/', size_x=3, size_y=3)

    if mode == 'test':
        # load all data
        test_data = Variable(data_loader.dataset.data_tensor)
        test_target = Variable(data_loader.dataset.target_tensor)
        full_data = Variable(train_loader.dataset.data_tensor)

        if args.cuda:
            test_data, test_target, full_data = test_data.cuda(), test_target.cuda(), full_data.cuda()

        full_data = torch.bernoulli(full_data)

        # VISUALIZATION: plot real images
        plot_real(args, test_data, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot reconstructions
        z2_mean_recon, z2_logvar_recon = model.q_z2(test_data)
        z2_recon = model.reparameterize(z2_mean_recon, z2_logvar_recon)
        z1_mean_recon, z1_logvar_recon = model.q_z1(test_data, z2_recon)
        z1_recon = model.reparameterize(z1_mean_recon, z1_logvar_recon)
        samples, _ = model.p_x(z1_recon, z2_recon)
        plot_reconstruction(args, samples, epoch, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot generations
        means = model.means(model.idle_input)[0:25]
        z2_sample_gen_mean, z2_sample_gen_logvar = model.q_z2(means)
        z2_sample_rand = model.reparameterize(z2_sample_gen_mean, z2_sample_gen_logvar)
        z1_mean_rand, z1_logvar_rand = model.p_z1(z2_sample_rand)
        z1_sample_rand = model.reparameterize(z1_mean_rand, z1_logvar_rand)
        samples_rand, _ = model.p_x(z1_sample_rand, z2_sample_rand)
        plot_generation(args, samples_rand, dir, size_x=5, size_y=5)

        if args.z1_size == 2 and args.z2_size == 2:
            # VISUALIZATION: plot low-dimensional manifold
            from utils.visual_evaluation import plot_manifold2
            plot_manifold2(model, args, dir)

            # VISUALIZATION: plot scatter-plot
            from utils.visual_evaluation import plot_scatter2
            z2_mean_recon, z2_logvar_recon = model.q_z2(test_data)
            z2_recon = model.reparameterize(z2_mean_recon, z2_logvar_recon)
            plot_scatter2(model, z2_recon, test_target, dir, name='scatter2D_z2.png')

            z1_mean_recon, z1_logvar_recon = model.q_z1(test_data, z2_recon)
            z1_recon = model.reparameterize(z1_mean_recon, z1_logvar_recon)
            plot_scatter2(model, z1_recon, test_target, dir, name='scatter2D_z1_1.png')

            z1_mean_recon, z1_logvar_recon = model.q_z1(test_data, z2_recon)
            z1_recon = model.reparameterize(z1_mean_recon, z1_logvar_recon)
            plot_scatter2(model, z1_recon, test_target, dir, name='scatter2D_z1_2.png')

        # CALCULATE lower-bound on test data
        t_ll_s = time.time()
        elbo_test = model.calculate_lower_bound(test_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE lower-bound on training data
        t_ll_s = time.time()
        elbo_train = model.calculate_lower_bound(full_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE log-likelihood on test data
        t_ll_s = time.time()
        log_likelihood_test = model.calculate_likelihood(test_data, dir, mode='test')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE log-likelihood on training data (skipped)
        t_ll_s = time.time()
        log_likelihood_train = 0.  # model.calculate_likelihood(full_data, dir, mode='train')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

    # calculate final loss
    evaluate_loss /= len(data_loader)  # loss function already averages over batch size
    evaluate_re /= len(data_loader)  # re already averages over batch size
    evaluate_kl /= len(data_loader)  # kl already averages over batch size

    if mode == 'test':
        return evaluate_loss, evaluate_re, evaluate_kl, log_likelihood_test, log_likelihood_train, elbo_test, elbo_train
    else:
        return evaluate_loss, evaluate_re, evaluate_kl
def calculate_loss(self, x, beta=1., average=False):
    # pass through VAE
    x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(x)

    # p(x|z)p(z)
    if self.args.input_type == 'binary':
        log_p_x_given_z = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
        log_p_x_given_z = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
    log_p_z2 = self.log_p_z2(z2_q)
    log_p_z = log_p_z1 + beta * log_p_z2

    # q(z|x)q(x)
    log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
    log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
    log_q_z_given_x = log_q_z1 + log_q_z2

    if self.args.q_x_prior == "marginal":
        # q(x) is marginal of p(x, z)
        log_q_x = log_p_x_given_z + log_p_z - log_q_z_given_x
    elif self.args.q_x_prior == "vampprior":
        # q(x) is vampprior of p(x|u)
        log_q_x = self.log_q_x_vampprior(x)

    RE = log_p_x_given_z
    KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

    # MIM loss
    loss = -0.5 * (log_p_x_given_z + log_p_z + beta * (log_q_z_given_x + log_q_x))

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)

    # symmetric sampling
    if self.p_samp and (beta >= 1.0):
        z2_q = None
        z1_q = None

        # p(x|z) sampling from PixelCNN
        z1_q, z2_q, x = self.generate_x(
            N=x.shape[0],
            return_z=True,
            z1=z1_q,
            z2=z2_q,
        )

        # discrete samples should have no gradients
        if self.args.input_type == 'binary':
            x = x.detach()

        x_shape = (-1, ) + tuple(self.args.input_size)
        x = x.view(x_shape)

        # z2 ~ q(z2 | x)
        z2_q_mean, z2_q_logvar = self.q_z2(x)
        # z1 ~ q(z1 | x, z2)
        z1_q_mean, z1_q_logvar = self.q_z1(x, z2_q)
        # p(z1 | z2)
        z1_p_mean, z1_p_logvar = self.p_z1(z2_q)
        # x_mean = p(x|z1,z2)
        x_mean, x_logvar = self.p_x(z1_q, z2_q)

        x = x.view((x.shape[0], -1))

        # p(x|z)p(z)
        if self.args.input_type == 'binary':
            log_p_x_given_z = log_Bernoulli(x, x_mean, dim=1)
        elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
            log_p_x_given_z = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_p_z2 = self.log_p_z2(z2_q)
        log_p_z = log_p_z1 + beta * log_p_z2

        # q(z|x)q(x)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        log_q_z_given_x = log_q_z1 + log_q_z2

        if self.args.q_x_prior == "marginal":
            # q(x) is marginal of p(x, z)
            log_q_x = log_p_x_given_z
        elif self.args.q_x_prior == "vampprior":
            # q(x) is vampprior of p(x|u)
            log_q_x = self.log_q_x_vampprior(x)

        loss_p = -0.5 * (log_p_x_given_z + log_p_z + beta * (log_q_z_given_x + log_q_x))

        # REINFORCE
        if self.args.input_type == 'binary':
            loss_p = loss_p + loss_p.detach() * log_p_x_given_z - (loss_p * log_p_x_given_z).detach()

        # MIM loss
        loss += beta * loss_p.mean()

    return loss, RE, KL
def calculate_loss(self, x, beta=1., average=False):
    '''
    :param x: input image(s)
    :param beta: a hyperparam for warmup
    :param average: whether to average loss or not
    :return: value of a loss function
    '''
    # pass through VAE
    x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

    # RE
    if self.args.input_type == 'binary':
        RE = log_Bernoulli(x, x_mean, dim=1)
    elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
        RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
    else:
        raise Exception('Wrong input type!')

    # KL
    log_p_z = self.log_p_z(z_q)
    log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    KL = -(log_p_z - log_q_z)

    if self.isnan(z_q_mean.data[0][0]):
        print("mean:")
        print(z_q_mean)
    if self.isnan(z_q_logvar.data[0][0]):
        print("var:")
        print(z_q_logvar)
    # print(z_q_logvar)

    # FI
    if self.args.FI is True:
        FI, gamma = self.FI(x)
    else:
        FI = Variable(torch.zeros(1), requires_grad=False)
        if self.args.cuda:
            FI = FI.cuda()
    # FI = (torch.mean((torch.log(2 * torch.pow(torch.exp(z_q_logvar), 2) + 1) - 2 * z_q_logvar)) - self.args.M).abs()
    # FI = (torch.mean((1 / torch.exp(z_q_logvar) + 1 / (2 * torch.pow(torch.exp(z_q_logvar), 2)))) - self.args.M).abs()

    # MI
    if self.args.MI is True:
        MI = self.MI(x)
    else:
        MI = Variable(torch.zeros(1), requires_grad=False)
        if self.args.cuda:
            MI = MI.cuda()

    loss = -RE + beta * KL  # + self.args.gamma * FI + self.args.ksi * MI  # - self.args.gamma * torch.log(FI)

    if self.args.FI is True:
        loss -= torch.mean(FI * gamma, dim=1)
        # print(FI)

    if average:
        loss = torch.mean(loss)
        RE = torch.mean(RE)
        KL = torch.mean(KL)
        FI = torch.mean(torch.exp(torch.mean(FI)))
        MI = torch.mean(MI)

    return loss, RE, KL, FI, MI
def vae_kl(x, x_decoded_mean):
    # KL part
    log_q = log_Normal_diag(z_0, z_mean, z_log_var)
    log_p = log_Normal_standard(z[str(config.number_of_flows)])
    kl_loss = log_q - log_p

    return K.mean(K.sum(kl_loss, axis=1), axis=0)
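# A hedged sketch of how the three Keras objectives above (vae_loss, vae_re, vae_kl)
# could be wired together, assuming `vae` is the Keras Model built elsewhere in the
# script and that the closure variables they use (z_0, z_mean, z_log_var, z, config)
# are in scope; the optimizer choice is illustrative only.
vae.compile(optimizer='rmsprop', loss=vae_loss, metrics=[vae_re, vae_kl])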