def compute_boosted_loss(mu_g, var_g, z_g, ldj_g, mu_G, var_G, z_G, ldj_G, y, y_logits, dim_prod, args):
    reduction = 'mean'

    # Full objective - converted to bits per dimension
    g_nll = -1.0 * (log_normal_diag(z_g, mu_g, var_g, dim=[1, 2, 3]) + ldj_g)
    unconstrained_G_lhood = log_normal_diag(z_G, mu_G, var_G, dim=[1, 2, 3]) + ldj_G
    # clamp the fixed components' log-likelihood from below for numerical stability
    G_nll = -1.0 * torch.max(unconstrained_G_lhood, torch.ones_like(ldj_G) * G_MAX_LOSS)
    nll = g_nll - G_nll
    bpd = nll / (math.log(2.) * dim_prod)

    losses = {"g_nll": torch.mean(g_nll)}
    losses["G_nll"] = torch.mean(G_nll)
    losses["nll"] = torch.mean(nll)
    losses["bpd"] = torch.mean(bpd)

    if args.y_condition:
        if args.multi_class:
            y_logits = torch.sigmoid(y_logits)
            loss_classes = F.binary_cross_entropy_with_logits(y_logits, y, reduction=reduction)
        else:
            loss_classes = F.cross_entropy(y_logits, torch.argmax(y, dim=1), reduction=reduction)

        losses["loss_classes"] = loss_classes
        losses["total_loss"] = losses["bpd"] + args.y_weight * loss_classes
    else:
        losses["total_loss"] = losses["bpd"]

    return losses
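# The loss functions in this file rely on a few module-level helpers defined elsewhere in the
# repository (log_normal_diag, log_normal_standard, safe_log, and the constant G_MAX_LOSS).
# The block below is only a minimal reference sketch of plausible implementations, assuming the
# third argument of log_normal_diag is a log-variance and that G_MAX_LOSS is a negative clamp
# value; the actual repository versions may differ (e.g. in whether the -0.5 * log(2*pi)
# constant is included, or in how the Glow-style callers interpret the variance argument).
import math
import torch

MIN_EPSILON = 1e-5     # assumed floor used by safe_log
G_MAX_LOSS = -10.0     # assumed lower bound on fixed-component log-likelihoods


def safe_log(x):
    # logarithm with a small floor to avoid log(0)
    return torch.log(torch.clamp(x, min=MIN_EPSILON))


def log_normal_diag(x, mean, log_var, dim=None):
    # log-density of a diagonal Gaussian, summed over `dim` (a single dim or a list of dims)
    log_norm = -0.5 * (math.log(2.0 * math.pi) + log_var + (x - mean) ** 2 / log_var.exp())
    return torch.sum(log_norm, dim)


def log_normal_standard(x, dim=None):
    # log-density of a standard Gaussian, summed over `dim`
    log_norm = -0.5 * (math.log(2.0 * math.pi) + x ** 2)
    return torch.sum(log_norm, dim)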
def boosted_neg_elbo(x_recon, x, z_mu, z_var, z_g, g_ldj, z_G, G_ldj, regularization_rate, first_component, args, beta=1.0):
    # Reconstruction term
    if args.input_type == "binary":
        reconstruction_function = nn.BCEWithLogitsLoss(reduction='sum')
        recon_loss = reconstruction_function(x_recon, x)
    elif args.input_type == "multinomial":
        num_classes = 256
        batch_size = x.size(0)

        if args.vae_layers == "linear":
            x_recon = x_recon.view(batch_size, num_classes, np.prod(args.input_size))
        else:
            x_recon = x_recon.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

        target = (x * (num_classes - 1)).long()
        recon_loss = cross_entropy(x=x_recon, target=target, reduction='sum')
    else:
        raise ValueError('Invalid input type for calculating loss: %s.' % args.input_type)

    # prior: ln p(z_k) (not averaged)
    log_p_zk = torch.sum(log_normal_standard(z_g[-1], dim=1))

    # entropy loss w.r.t. the new component's terms (not averaged)
    # N E_g[ ln g(z | x) ] (not averaged)
    log_g_base = log_normal_diag(z_g[0], mean=z_mu, log_var=safe_log(z_var), dim=1)
    log_g_z = log_g_base - g_ldj

    if first_component or (z_G is None and G_ldj is None):
        # train the first component just like a standard VAE + normalizing flow,
        # or if we sampled from all components to alleviate decoder shock
        entropy = torch.sum(log_g_z)
        log_G_z = torch.zeros_like(entropy)
        log_ratio = torch.zeros_like(entropy).detach()
    else:
        # all other components are trained using the boosted loss
        # loss w.r.t. fixed component terms:
        log_G_base = log_normal_diag(z_G[0], mean=z_mu, log_var=safe_log(z_var), dim=1)
        log_G_z = torch.clamp(log_G_base - G_ldj, min=-1000.0)
        log_ratio = torch.sum(log_G_z.data - log_g_z.data).detach()

        # limit log-likelihoods of the FIXED components to a small number for numerical stability
        log_G_z = torch.sum(torch.max(log_G_z, torch.ones_like(G_ldj) * G_MAX_LOSS))
        entropy = torch.sum(regularization_rate * log_g_z)

    loss = recon_loss + log_G_z + beta * (entropy - log_p_zk)

    batch_size = float(x.size(0))
    loss = loss / batch_size
    recon_loss = recon_loss / batch_size
    log_G_z = log_G_z / batch_size
    log_p_zk = -1.0 * log_p_zk / batch_size
    entropy = entropy / batch_size
    log_ratio = log_ratio / batch_size

    return loss, recon_loss, log_G_z, log_p_zk, entropy, log_ratio
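# Reading of boosted_neg_elbo above: the first component is trained like a standard
# VAE + normalizing flow (the fixed-component term is zero and the entropy term is not
# reweighted), while every later component g is trained against the fixed mixture G of
# previously trained components. Summed over the batch (and divided by batch size at the end),
# the objective is
#     loss = sum[-ln p(x | z_k)]
#            + sum[max(ln G(z | x), G_MAX_LOSS)]
#            + beta * ( regularization_rate * sum[ln g(z | x)] - sum[ln p(z_k)] ),
# where ln g(z | x) = ln g(z_0 | x) - g_ldj (and analogously for G). Relative to a standard
# negative ELBO, the extra +ln G(z | x) term discourages the new component from duplicating
# the fixed components. log_ratio is returned only for monitoring and is detached from the
# computation graph.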
def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices):
    z_q, z_q_mean, z_q_logvar = latent_stats
    if exemplars_embedding is None and self.args.prior == 'exemplar_prior':
        exemplars_embedding = self.get_exemplar_set(z_q_mean, z_q_logvar, dataset, cache, x_indices)
    log_p_z = self.log_p_z(z=(z_q, x_indices), exemplars_embedding=exemplars_embedding)
    log_q_z = log_normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    return -(log_p_z - log_q_z)
def compute_loss(z, z_mu, z_var, logdet, y, y_logits, dim_prod, args):
    reduction = 'mean'

    # Full objective - converted to bits per dimension
    nll = -1.0 * (log_normal_diag(z, z_mu, z_var, dim=[1, 2, 3]) + logdet)
    bpd = nll / (math.log(2.) * dim_prod)
    losses = {"bpd": torch.mean(bpd)}

    if args.y_condition:
        if args.multi_class:
            y_logits = torch.sigmoid(y_logits)
            loss_classes = F.binary_cross_entropy_with_logits(y_logits, y, reduction=reduction)
        else:
            loss_classes = F.cross_entropy(y_logits, torch.argmax(y, dim=1), reduction=reduction)

        losses["loss_classes"] = loss_classes
        losses["total_loss"] = losses["bpd"] + args.y_weight * loss_classes
    else:
        losses["total_loss"] = losses["bpd"]

    return losses
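# Hypothetical smoke test for compute_loss (not part of the original module). It fabricates
# Glow-style outputs with random tensors just to show the expected shapes, and that dim_prod
# (the number of dimensions per image) converts the summed NLL into bits per dimension:
# bpd = nll / (log(2) * C*H*W). The `args` namespace only carries the flags referenced above.
def _compute_loss_smoke_test():
    from types import SimpleNamespace
    B, C, H, W = 4, 3, 8, 8
    z = torch.randn(B, C, H, W)
    # unit scale regardless of whether log_normal_diag interprets its third argument
    # as a variance or a log-variance
    z_mu, z_var = torch.zeros_like(z), torch.ones_like(z)
    logdet = torch.zeros(B)
    args = SimpleNamespace(y_condition=False, multi_class=False, y_weight=0.01)
    dim_prod = C * H * W
    losses = compute_loss(z, z_mu, z_var, logdet, y=None, y_logits=None, dim_prod=dim_prod, args=args)
    print("bpd: %.3f  total: %.3f" % (losses["bpd"].item(), losses["total_loss"].item()))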
def binary_loss_array(x_recon, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss without averaging or summing over the batch dimension.
    """
    batch_size = x.size(0)

    # if not summed over the batch dimension
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    reconstruction_function = nn.BCEWithLogitsLoss(reduction='none')
    bce = reconstruction_function(x_recon.view(batch_size, -1), x.view(batch_size, -1))
    # sum over feature dimension
    bce = bce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)
    return loss
def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (cross-entropy) loss without averaging or summing over the batch dimension.
    """
    num_classes = 256
    batch_size = x.size(0)

    if args.vae_layers == "linear":
        x_logit = x_logit.view(batch_size, num_classes, np.prod(args.input_size))
    else:
        x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # computes cross entropy over all dimensions separately:
    ce_loss_function = nn.CrossEntropyLoss(reduction='none')
    ce = ce_loss_function(x_logit, target)
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0.view(batch_size, -1), mean=z_mu.view(batch_size, -1),
                               log_var=safe_log(z_var).view(batch_size, -1), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)
    return loss
def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss without averaging or summing over the batch dimension.
    """
    batch_size = x.size(0)

    # if not summed over the batch dimension
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    # TODO: upgrade to the newest PyTorch version on the master branch; there nn.BCELoss comes
    # with the option `reduce`, which when set to False does not sum over the batch dimension.
    bce = -log_bernoulli(x.view(batch_size, -1), recon_x.view(batch_size, -1), dim=1)
    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)
    return loss
def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices):
    z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = latent_stats
    if exemplars_embedding is None and self.args.prior == 'exemplar_prior':
        exemplars_embedding = self.get_exemplar_set(z2_q_mean, z2_q_logvar, dataset, cache, x_indices)
    log_p_z1 = log_normal_diag(z1_q.view(-1, self.args.z1_size),
                               z1_p_mean.view(-1, self.args.z1_size),
                               z1_p_logvar.view(-1, self.args.z1_size), dim=1)
    log_q_z1 = log_normal_diag(z1_q.view(-1, self.args.z1_size),
                               z1_q_mean.view(-1, self.args.z1_size),
                               z1_q_logvar.view(-1, self.args.z1_size), dim=1)
    log_p_z2 = self.log_p_z(z=(z2_q, x_indices), exemplars_embedding=exemplars_embedding)
    log_q_z2 = log_normal_diag(z2_q.view(-1, self.args.z2_size),
                               z2_q_mean.view(-1, self.args.z2_size),
                               z2_q_logvar.view(-1, self.args.z2_size), dim=1)
    return -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)
def neg_elbo(x_recon, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.0):
    """
    Computes the negative ELBO for binary or multinomial inputs, summed over the feature
    dimensions and averaged over the batch dimension.

    :param x_recon: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled to [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global parameter settings
    :param beta: beta for the KL term
    :return: loss, recon_loss, kl
    """
    if args.input_type == "binary":
        # - N E_q0 [ ln p(x|z_k) ]
        reconstruction_function = nn.BCEWithLogitsLoss(reduction='sum')
        recon_loss = reconstruction_function(x_recon, x)
    elif args.input_type == "multinomial":
        num_classes = 256
        batch_size = x.size(0)

        if args.vae_layers == "linear":
            x_recon = x_recon.view(batch_size, num_classes, np.prod(args.input_size))
        else:
            x_recon = x_recon.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

        # make integer class labels
        target = (x * (num_classes - 1)).long()

        # - N E_q0 [ ln p(x|z_k) ]
        # sums over batch dimension (and feature dimension)
        recon_loss = cross_entropy(x=x_recon, target=target, reduction='sum')
    else:
        raise ValueError('Invalid input type for calculating loss: %s.' % args.input_type)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[ \sum_k log |det dz_k/dz_{k-1}| ]
    kl = summed_logs - summed_ldj
    loss = recon_loss + beta * kl

    batch_size = x.size(0)
    loss = loss / float(batch_size)
    recon_loss = recon_loss / float(batch_size)
    kl = kl / float(batch_size)

    return loss, recon_loss, kl
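# As implemented above, the per-batch objective (divided by batch size before returning) is
#     -ELBO = sum[-ln p(x | z_k)] + beta * ( sum[ln q(z_0 | x) - ln p(z_k)] - sum[ldj] ),
# i.e. the reconstruction term plus beta times the KL between the flow posterior and the
# standard-normal prior, where ldj = sum_k ln |det dz_k / dz_{k-1}| corrects the base density
# q(z_0 | x) for the flow transformation.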
def multinomial_loss_function(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (cross entropy) negative ELBO, summed over the feature dimensions
    and averaged over the batch dimension.

    :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real-valued logits
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled to [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global parameter settings
    :param beta: beta for the KL term
    :return: loss, ce, kl
    """
    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # sums over batch dimension (and feature dimension)
    ce = F.cross_entropy(x_logit, target, reduction='sum')

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[ \sum_k log |det dz_k/dz_{k-1}| ]
    kl = summed_logs - summed_ldj
    loss = ce + beta * kl

    loss = loss / float(batch_size)
    ce = ce / float(batch_size)
    kl = kl / float(batch_size)

    return loss, ce, kl
def forward(self, sample, logdet=0.0, reverse=False, temperature=None):
    if reverse:
        z1 = sample
        z_mu, z_var = self.split2d_prior(z1)
        z2 = torch.normal(z_mu, torch.exp(z_var) * temperature)
        z = torch.cat((z1, z2), dim=1)
        return z, logdet
    else:
        z1, z2 = split_feature(sample, "split")
        z_mu, z_var = self.split2d_prior(z1)
        logdet = log_normal_diag(z2, z_mu, z_var, dim=[1, 2, 3]) + logdet
        return z1, logdet
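# Reading of the Split2d-style forward above: in the forward (encoding) direction half of the
# channels (z2) are factored out and scored under a learned conditional Gaussian prior p(z2 | z1),
# and that log-density is added to logdet so the overall likelihood still accounts for the removed
# dimensions; in the reverse (sampling) direction z2 is re-drawn from the same prior with its scale
# exp(z_var) multiplied by `temperature`. Whether z_var is a variance or a log-scale depends on
# split2d_prior, which is not shown in this file.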
def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (cross-entropy) loss without averaging or summing over the batch dimension.
    """
    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # computes cross entropy over all dimensions separately:
    ce = cross_entropy(x_logit, target, reduction='none')
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0.view(batch_size, -1), mean=z_mu.view(batch_size, -1),
                               log_var=z_var.log().view(batch_size, -1), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)
    return loss
def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary negative ELBO, summed over the feature dimensions and averaged over
    the batch dimension.

    :param recon_x: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled to [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for the KL term
    :return: loss, bce, kl
    """
    reconstruction_function = nn.BCELoss(reduction='sum')

    batch_size = x.size(0)

    # - N E_q0 [ ln p(x|z_k) ]
    bce = reconstruction_function(recon_x, x)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[ \sum_k log |det dz_k/dz_{k-1}| ]
    kl = summed_logs - summed_ldj
    loss = bce + beta * kl

    loss /= float(batch_size)
    bce /= float(batch_size)
    kl /= float(batch_size)

    return loss, bce, kl