def binary_loss_array(x_recon, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss without averaging or summing over the batch dimension.
    """
    batch_size = x.size(0)

    # if the log-det-jacobian was not already summed over the feature dimensions, do so now
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    reconstruction_function = nn.BCEWithLogitsLoss(reduction='none')
    bce = reconstruction_function(x_recon.view(batch_size, -1), x.view(batch_size, -1))
    # sum over feature dimension
    bce = bce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)
    return loss

def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (cross-entropy) loss without averaging or summing over the batch dimension.
    """
    num_classes = 256
    batch_size = x.size(0)

    if args.vae_layers == "linear":
        x_logit = x_logit.view(batch_size, num_classes, np.prod(args.input_size))
    else:
        x_logit = x_logit.view(batch_size, num_classes,
                               args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # computes cross entropy over all dimensions separately:
    ce_loss_function = nn.CrossEntropyLoss(reduction='none')
    ce = ce_loss_function(x_logit, target)
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0.view(batch_size, -1),
                               mean=z_mu.view(batch_size, -1),
                               log_var=safe_log(z_var).view(batch_size, -1),
                               dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)
    return loss

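# The array-style losses above return one value per example precisely so that callers
# can apply their own weighting or reduction. Below is a minimal sketch of one common
# use: an importance-weighted (IWAE-style) estimate of -ln p(x). The `model` call and
# its output order (x_recon, z_mu, z_var, ldj, z_0, z_k) are illustrative assumptions,
# not part of this module.
import numpy as np
import torch


def estimate_binary_nll(model, x, num_importance_samples=100):
    log_ws = []
    for _ in range(num_importance_samples):
        x_recon, z_mu, z_var, ldj, z_0, z_k = model(x)
        # the negated per-sample loss is the log importance weight:
        # ln p(x|z_k) + ln p(z_k) - ln q(z_0|x) + ldj
        log_ws.append(-binary_loss_array(x_recon, x, z_mu, z_var, z_0, z_k, ldj))
    log_ws = torch.stack(log_ws, dim=1)  # (batch_size, num_importance_samples)
    # ln p(x) ~= logsumexp_s(log w_s) - ln S
    log_px = torch.logsumexp(log_ws, dim=1) - np.log(num_importance_samples)
    return -log_px.mean()
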
def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss without averaging or summing over the batch dimension.
    """
    batch_size = x.size(0)

    # if the log-det-jacobian was not already summed over the feature dimensions, do so now
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    # per-sample Bernoulli reconstruction loss, summed over the feature dimension
    # (equivalent to nn.BCELoss(reduction='none') followed by a feature-wise sum)
    bce = -log_bernoulli(x.view(batch_size, -1), recon_x.view(batch_size, -1), dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)
    return loss

def _rho_gradient_g(self, x):
    """
    Monte Carlo term for the gradient w.r.t. rho: draw a sample z_K ~ g^c
    (the new component) and compute its log-likelihood under that component.
    """
    z_g, mu_g, var_g, ldj_g, _ = self.forward(x=x, components="c")
    g_ll = log_normal_standard(z_g, reduce=True, dim=-1, device=self.args.device) + ldj_g
    return g_ll.detach()

def boosted_neg_elbo(x_recon, x, z_mu, z_var, z_g, g_ldj, z_G, G_ldj,
                     regularization_rate, first_component, args, beta=1.0):
    # Reconstruction term
    if args.input_type == "binary":
        reconstruction_function = nn.BCEWithLogitsLoss(reduction='sum')
        recon_loss = reconstruction_function(x_recon, x)
    elif args.input_type == "multinomial":
        num_classes = 256
        batch_size = x.size(0)
        if args.vae_layers == "linear":
            x_recon = x_recon.view(batch_size, num_classes, np.prod(args.input_size))
        else:
            x_recon = x_recon.view(batch_size, num_classes,
                                   args.input_size[0], args.input_size[1], args.input_size[2])
        target = (x * (num_classes - 1)).long()
        recon_loss = cross_entropy(x=x_recon, target=target, reduction='sum')
    else:
        raise ValueError('Invalid input type for loss calculation: %s.' % args.input_type)

    # prior: ln p(z_k) (summed over the batch, not averaged)
    log_p_zk = torch.sum(log_normal_standard(z_g[-1], dim=1))

    # entropy loss w.r.t. the new component's terms (not averaged):
    # N E_g[ ln g(z | x) ]
    log_g_base = log_normal_diag(z_g[0], mean=z_mu, log_var=safe_log(z_var), dim=1)
    log_g_z = log_g_base - g_ldj

    if first_component or (z_G is None and G_ldj is None):
        # train the first component just like a standard VAE + normalizing flow,
        # or if we sampled from all components to alleviate decoder shock
        entropy = torch.sum(log_g_z)
        log_G_z = torch.zeros_like(entropy)
        log_ratio = torch.zeros_like(entropy).detach()
    else:
        # all other components are trained using the boosted loss
        # loss w.r.t. the fixed components' terms:
        log_G_base = log_normal_diag(z_G[0], mean=z_mu, log_var=safe_log(z_var), dim=1)
        log_G_z = torch.clamp(log_G_base - G_ldj, min=-1000.0)
        log_ratio = torch.sum(log_G_z.data - log_g_z.data).detach()
        # bound the log-likelihoods of the FIXED components from below for numerical stability
        log_G_z = torch.sum(torch.max(log_G_z, torch.ones_like(G_ldj) * G_MAX_LOSS))
        entropy = torch.sum(regularization_rate * log_g_z)

    loss = recon_loss + log_G_z + beta * (entropy - log_p_zk)

    batch_size = float(x.size(0))
    loss = loss / batch_size
    recon_loss = recon_loss / batch_size
    log_G_z = log_G_z / batch_size
    log_p_zk = -1.0 * log_p_zk / batch_size
    entropy = entropy / batch_size
    log_ratio = log_ratio / batch_size
    return loss, recon_loss, log_G_z, log_p_zk, entropy, log_ratio

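# A minimal sketch of one training step for a new boosted component using
# boosted_neg_elbo. The forward signature and output order of `model` -- returning
# (x_recon, z_mu, z_var, z_list, ldj), where z_list holds (z_0, ..., z_K) -- are
# illustrative assumptions; the fixed components "1:c-1" are evaluated without gradients.
import torch


def boosted_train_step(model, optimizer, x, args, reg_rate=1.0, beta=1.0):
    optimizer.zero_grad()
    first = (model.component == 0 and not model.all_trained)
    x_recon, z_mu, z_var, z_g, g_ldj = model(x, components="c")
    if first:
        z_G, G_ldj = None, None  # first component: plain VAE + flow loss
    else:
        with torch.no_grad():
            _, _, _, z_G, G_ldj = model(x, components="1:c-1")
    loss, recon, log_G, log_p, entropy, ratio = boosted_neg_elbo(
        x_recon, x, z_mu, z_var, z_g, g_ldj, z_G, G_ldj,
        regularization_rate=reg_rate, first_component=first, args=args, beta=beta)
    loss.backward()
    optimizer.step()
    return loss.item()
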
def _rho_gradient_G(self, x):
    """
    Monte Carlo term for the gradient w.r.t. rho: draw a sample z_K ~ G^(c-1)
    (the fixed components) and compute its log-likelihood under those components.
    """
    fixed = "-c" if self.all_trained else "1:c-1"
    z_G, mu_G, var_G, ldj_G, _ = self.forward(x=x, components=fixed)
    G_ll = log_normal_standard(z_G, reduce=True, dim=-1, device=self.args.device) + ldj_G
    return G_ll.detach()

def neg_elbo(x_recon, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.0):
    """
    Computes the negative ELBO, summing over the batch dimension. Not averaged!

    :param x_recon: shape: (batch_size, num_channels, pixel_width, pixel_height),
        Bernoulli logits for binary inputs, or class logits for multinomial inputs
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global parameter settings
    :param beta: beta for kl loss
    :return: loss, recon_loss, kl
    """
    if args.input_type == "binary":
        # - N E_q0 [ ln p(x|z_k) ]
        reconstruction_function = nn.BCEWithLogitsLoss(reduction='sum')
        recon_loss = reconstruction_function(x_recon, x)
    elif args.input_type == "multinomial":
        num_classes = 256
        batch_size = x.size(0)
        if args.vae_layers == "linear":
            x_recon = x_recon.view(batch_size, num_classes, np.prod(args.input_size))
        else:
            x_recon = x_recon.view(batch_size, num_classes,
                                   args.input_size[0], args.input_size[1], args.input_size[2])
        # make integer class labels
        target = (x * (num_classes - 1)).long()
        # - N E_q0 [ ln p(x|z_k) ], summed over batch and feature dimensions
        recon_loss = cross_entropy(x=x_recon, target=target, reduction='sum')
    else:
        raise ValueError('Invalid input type for loss calculation: %s.' % args.input_type)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches: ldj = N E_q0[ sum_k log |det dz_k/dz_{k-1}| ]
    summed_ldj = torch.sum(ldj)

    kl = summed_logs - summed_ldj
    loss = recon_loss + beta * kl

    batch_size = float(x.size(0))
    loss = loss / batch_size
    recon_loss = recon_loss / batch_size
    kl = kl / batch_size
    return loss, recon_loss, kl

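# A minimal training-step sketch around neg_elbo, assuming a VAE-with-flow `model`
# whose forward returns (x_recon, z_mu, z_var, ldj, z_0, z_k); that output order is an
# assumption for illustration only.
def vae_train_step(model, optimizer, x, args, beta=1.0):
    optimizer.zero_grad()
    x_recon, z_mu, z_var, ldj, z_0, z_k = model(x)
    loss, recon_loss, kl = neg_elbo(x_recon, x, z_mu, z_var, z_0, z_k, ldj, args, beta=beta)
    loss.backward()  # neg_elbo already averages over the batch
    optimizer.step()
    return loss.item(), recon_loss.item(), kl.item()
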
def multinomial_loss_function(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the cross-entropy loss function, summing over the batch dimension. Not averaged!

    :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global parameter settings
    :param beta: beta for kl loss
    :return: loss, ce, kl
    """
    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes,
                           args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # sums over batch dimension (and feature dimension)
    ce = F.cross_entropy(x_logit, target, reduction='sum')

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches: ldj = N E_q0[ sum_k log |det dz_k/dz_{k-1}| ]
    summed_ldj = torch.sum(ldj)

    kl = summed_logs - summed_ldj
    loss = ce + beta * kl

    loss = loss / float(batch_size)
    ce = ce / float(batch_size)
    kl = kl / float(batch_size)
    return loss, ce, kl

def _rho_gradients(self, x):
    """
    Compute the per-sample log-likelihoods needed for the gradient w.r.t. the mixture
    weights rho: the new component g^c, the fixed mixture G^(c-1), and the full model G^c.
    """
    full_ll = torch.zeros(x.size(0))
    fixed_ll = torch.zeros(x.size(0))
    new_ll = torch.zeros(x.size(0))

    for c in range(self.component + 1):
        z, _, _, ldj, _ = self.forward(x=x, components=c)
        if c == 0:
            full_ll = log_normal_standard(z, reduce=True, dim=-1, device=self.args.device) + ldj
        else:
            new_ll = log_normal_standard(z, reduce=True, dim=-1, device=self.args.device) + ldj
            # update the full model using the recursive formula:
            # log G^c = logsumexp( log(1 - rho_c) + log G^(c-1),  log(rho_c) + log g^c )
            prev_ll = (torch.log(1 - self.rho[c]) + full_ll).view(x.size(0), 1)
            next_ll = (torch.log(self.rho[c]) + new_ll).view(x.size(0), 1)
            full_ll = torch.logsumexp(torch.cat([prev_ll, next_ll], dim=1), dim=1)

        if c == self.component - 1:
            fixed_ll = full_ll

    return new_ll, fixed_ll, full_ll

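# A self-contained sanity check (with made-up numbers) of the recursive mixture update
# used above: adding component c with weight rho to a fixed mixture G^(c-1) gives
# log G^c(x) = logsumexp( log(1 - rho) + log G^(c-1)(x), log(rho) + log g^c(x) ).
import torch


def _mixture_recursion_example():
    prev_ll = torch.tensor([-3.0, -2.5])  # log G^(c-1)(x) for two samples
    new_ll = torch.tensor([-2.0, -4.0])   # log g^c(x)
    rho = torch.tensor(0.25)
    mix_ll = torch.logsumexp(
        torch.stack([torch.log(1 - rho) + prev_ll, torch.log(rho) + new_ll], dim=1), dim=1)
    # the same quantity computed directly (numerically less stable):
    direct = torch.log((1 - rho) * prev_ll.exp() + rho * new_ll.exp())
    assert torch.allclose(mix_ll, direct)
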
def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (cross-entropy) loss without averaging or summing over the batch dimension.
    """
    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes,
                           args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # computes cross entropy over all dimensions separately
    # (reduction='none' replaces the deprecated size_average/reduce arguments):
    ce = cross_entropy(x_logit, target, reduction='none')
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0.view(batch_size, -1),
                               mean=z_mu.view(batch_size, -1),
                               log_var=safe_log(z_var).view(batch_size, -1),
                               dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)
    return loss

def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss function, summing over the batch dimension. Not averaged!

    :param recon_x: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, bce, kl
    """
    # reduction='sum' replaces the deprecated `size_average = False` attribute
    reconstruction_function = nn.BCELoss(reduction='sum')

    batch_size = x.size(0)

    # - N E_q0 [ ln p(x|z_k) ]
    bce = reconstruction_function(recon_x, x)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches: ldj = N E_q0[ sum_k log |det dz_k/dz_{k-1}| ]
    summed_ldj = torch.sum(ldj)

    kl = summed_logs - summed_ldj
    loss = bce + beta * kl

    loss /= float(batch_size)
    bce /= float(batch_size)
    kl /= float(batch_size)
    return loss, bce, kl

def compute_kl_pq_loss(model, x, args):
    if args.flow == "boosted":
        if model.all_trained or model.component > 0:
            # 1. Compute the likelihood/weight of each sample under the fixed mixture G,
            #    accumulating components with the recursive logsumexp formula
            G_ll = torch.zeros(x.size(0))
            for c in range(model.component):
                z_G, _, _, ldj_G, _ = model(x=x, components=c)
                if c == 0:
                    G_ll = log_normal_standard(z_G, reduce=True, dim=-1, device=args.device) + ldj_G
                else:
                    rho_simplex = model.rho[0:(c + 1)] / torch.sum(model.rho[0:(c + 1)])
                    last_ll = torch.log(1 - rho_simplex[c]) + G_ll
                    next_ll = torch.log(rho_simplex[c]) + \
                        (log_normal_standard(z_G, reduce=True, dim=-1, device=args.device) + ldj_G)
                    uG_ll = torch.cat([last_ll.view(x.size(0), 1), next_ll.view(x.size(0), 1)], dim=1)
                    G_ll = torch.logsumexp(uG_ll, dim=1)

            G_nll = -1.0 * G_ll

            # 2. Sample x with replacement, weighted by G_nll
            weights = softmax(G_nll, dim=0)

            # temper the weights with one of several heuristics
            heuristic = "unity"
            if heuristic == "decay":
                beta = 1.0 / (2.0 ** model.component)
            elif heuristic == "uniform":
                beta = 1.0 / (1.0 + args.num_components)
            else:
                beta = 1.0  # unity

            weights = torch.pow(weights, beta)

            # clip weights to [0.01, 0.1] if any sample dominates, then renormalize
            if weights.max() > 0.1:
                weights = torch.max(torch.min(weights, torch.tensor([0.1], device=args.device)),
                                    torch.tensor([0.01], device=args.device))
            if weights.sum() != 1.0:
                weights = weights / torch.sum(weights)

            reweighted_idx = torch.multinomial(weights, x.size(0), replacement=True)
            x_resampled = x[reweighted_idx]

            # 3. Compute g for the resampled observations
            z_g, _, _, ldj_g, _ = model(x=x_resampled, components="c")
            g_nll = -1.0 * (log_normal_standard(z_g, reduce=True, dim=-1, device=args.device) + ldj_g)

            nll = torch.mean(g_nll)
            losses = {"nll": nll}
            losses["G_nll"] = torch.mean(G_nll)
            losses["g_nll"] = torch.mean(g_nll)
        else:
            # train the first boosted component just like a non-boosted model
            z_g, _, _, ldj_g, _ = model(x=x, components="c")
            g_nll = -1.0 * (log_normal_standard(z_g, reduce=True, dim=-1, device=args.device) + ldj_g)
            losses = {"nll": torch.mean(g_nll)}
            losses["g_nll"] = torch.mean(g_nll)
            losses["G_nll"] = torch.zeros_like(losses['g_nll'])
    else:
        z, _, _, log_det_j, _ = model(x=x)
        log_pz = log_normal_standard(z, reduce=True, dim=-1, device=args.device)
        nll = -1.0 * (log_pz + log_det_j)
        losses = {"nll": torch.mean(nll)}
        losses['log_det_jacobian'] = torch.mean(log_det_j)
        losses['log_pz'] = torch.mean(log_pz)

    if torch.isnan(losses['nll']).any():
        raise ValueError(f"Nan Encountered. nll={losses['nll']}, x={x}, losses={losses}")

    return losses

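# A minimal sketch of one epoch of density-estimation training driven by
# compute_kl_pq_loss; `model`, `optimizer`, and `data_loader` are assumed to exist and
# follow the conventions used above.
def train_epoch(model, data_loader, optimizer, args):
    model.train()
    for x, _ in data_loader:
        x = x.to(args.device)
        optimizer.zero_grad()
        losses = compute_kl_pq_loss(model, x, args)
        losses['nll'].backward()
        optimizer.step()
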
def evaluate(model, data_loader, args, results_type=None):
    model.eval()

    if args.boosted:
        G_nll, g_nll = [], []
        for (x, _) in data_loader:
            x = x.to(args.device)

            # if True, approximate the fixed mixture G by randomly sampling one
            # component per batch instead of evaluating every component exactly
            approximate_fixed_G = False
            if approximate_fixed_G:
                z_G, _, _, ldj_G, _ = model(x=x, components="1:c")
                G_nll_i = -1.0 * (log_normal_standard(z_G, reduce=True, dim=-1, device=args.device) + ldj_G)
                G_nll.append(G_nll_i.detach())
            else:
                # evaluate the full mixture exactly with the recursive logsumexp formula
                G_ll = torch.zeros(x.size(0))
                for c in range(model.component + 1):
                    z_G, _, _, ldj_G, _ = model(x=x, components=c)
                    if c == 0:
                        G_ll = log_normal_standard(z_G, reduce=True, dim=-1, device=args.device) + ldj_G
                    else:
                        rho_simplex = model.rho[0:(c + 1)] / torch.sum(model.rho[0:(c + 1)])
                        last_ll = torch.log(1 - rho_simplex[c]) + G_ll
                        next_ll = torch.log(rho_simplex[c]) + \
                            (log_normal_standard(z_G, reduce=True, dim=-1, device=args.device) + ldj_G)
                        uG_ll = torch.cat([last_ll.view(x.size(0), 1), next_ll.view(x.size(0), 1)], dim=1)
                        G_ll = torch.logsumexp(uG_ll, dim=1)

                G_nll.append(-1.0 * G_ll.detach())

            # track the new component's progress just for insights
            if model.component > 0 or model.all_trained:
                z_g, mu_g, var_g, ldj_g, _ = model(x=x, components="c")
                g_nll_i = -1.0 * (log_normal_standard(z_g, reduce=True, dim=-1, device=args.device) + ldj_g)
                g_nll.append(g_nll_i.detach())

        G_nll = torch.cat(G_nll, dim=0)
        mean_G_nll = G_nll.mean().item()
        losses = {'nll': mean_G_nll}

        if model.component > 0 or model.all_trained:
            g_nll = torch.cat(g_nll, dim=0)
            losses['g_nll'] = g_nll.mean().item()
            losses['ratio'] = torch.mean(g_nll - G_nll).item()
        else:
            losses['g_nll'] = mean_G_nll
            losses['ratio'] = 0.0
    else:
        nll = torch.stack([
            compute_kl_pq_loss(model, x.to(args.device), args)['nll'].detach()
            for (x, _) in data_loader
        ], -1).mean().item()
        losses = {'nll': nll, 'g_nll': nll}

    if args.save_results and results_type is not None:
        results_msg = f'{results_type} set loss: {losses["nll"]:.6f}'
        logger.info(results_msg + '\n')
        with open(args.exp_log, 'a') as ff:
            print(results_msg, file=ff)

    return losses