import torch
import torch.nn as nn
from torch.nn import functional as F

# NOTE: log_normal_standard, log_normal_diag, log_bernoulli, and the element-wise
# (non-reducing) binary_cross_entropy_with_logits / cross_entropy helpers used by
# `elbo_loss` are assumed to be imported from the project's distribution utilities
# (e.g. a utils.distributions module, as in the Sylvester normalizing-flows codebase).


def elbo_loss(recon_image, image, recon_text, text, z_mu, z_var, z_0, z_k, ldj, args,
              lambda_image=1.0, lambda_text=1.0, annealing_factor=1.0, beta=1.):
    """Bimodal ELBO loss for an image/text VAE with a normalizing-flow posterior."""
    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    # NOTE: here z_var is interpreted as a log-variance; the losses below expect a
    # variance and call z_var.log() themselves.
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var, dim=1)

    # E_q0[ ln q(z_0) - ln p(z_k) ], per example
    logs = log_q_z0 - log_p_zk

    # ldj = E_q0[ sum_k log |det dz_k/dz_{k-1}| ]; subtracting it gives the flow KL term
    kl = logs.sub(ldj).to(torch.double)

    image_bce, text_bce = 0.0, 0.0  # defaults when a modality is missing
    if recon_image is not None and image is not None:
        # element-wise BCE (non-reducing helper) summed over pixels
        image_bce = torch.sum(binary_cross_entropy_with_logits(
            recon_image.view(-1, 1 * 28 * 28),
            image.view(-1, 1 * 28 * 28)), dim=1, dtype=torch.double)

    if recon_text is not None and text is not None:
        # element-wise cross entropy (non-reducing helper) summed over classes
        text_bce = torch.sum(cross_entropy(recon_text, text),
                             dim=1, dtype=torch.double)

    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    ELBO = torch.mean(lambda_image * image_bce + lambda_text * text_bce
                      + annealing_factor * kl)

    return ELBO, image_bce, text_bce, kl

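# Illustrative usage sketch (not part of the original module): one training step on the
# bimodal ELBO with the common linear KL warm-up for `annealing_factor`. The bimodal
# `model`, its output signature, and the optimizer are hypothetical placeholder names.
def _example_elbo_training_step(model, optimizer, image, text, args,
                                epoch, annealing_epochs=20):
    """Run one gradient step on the bimodal ELBO with an annealed KL weight."""
    # Linearly anneal the KL weight from 0 to 1 over `annealing_epochs` epochs.
    annealing_factor = min(1.0, epoch / float(annealing_epochs))

    # Assumed model interface: reconstructions plus flow posterior quantities.
    recon_image, recon_text, z_mu, z_var, z_0, z_k, ldj = model(image, text)

    loss, image_bce, text_bce, kl = elbo_loss(
        recon_image, image, recon_text, text,
        z_mu, z_var, z_0, z_k, ldj, args,
        annealing_factor=annealing_factor)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
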
def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (256-way categorical) loss per example, without averaging
    or summing over the batch dimension.
    """
    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes,
                           args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels from pixel values in [0, 1]
    target = (x * (num_classes - 1)).long()

    # - E_q0 [ ln p(x|z_k) ]
    # cross entropy over every pixel separately (no reduction)
    ce = F.cross_entropy(x_logit, target, reduction='none')
    # sum over the feature dimensions
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(
        z_0.view(batch_size, -1),
        mean=z_mu.view(batch_size, -1),
        log_var=z_var.log().view(batch_size, -1),
        dim=1
    )

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)

    return loss

def binary_loss_function(recon_x1, x1, recon_x2, x2, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary (Bernoulli) loss for two modalities. Reconstruction and KL terms
    are summed over the batch and then divided by the batch size before returning.

    :param recon_x1, recon_x2: shape (batch_size, num_channels, pixel_width, pixel_height),
        Bernoulli parameters p(x=1) for each modality
    :param x1, x2: shape (batch_size, num_channels, pixel_width, pixel_height),
        pixel values rescaled to [0, 1]
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det Jacobian of the flow
    :param beta: weight on the KL term
    :return: loss, bce1, bce2, kl
    """
    reconstruction_function = nn.BCELoss(reduction='sum')

    # at least one modality must be present to determine the batch size
    if x1 is not None:
        batch_size = x1.size(0)
    if x2 is not None:
        batch_size = x2.size(0)

    # - N E_q0 [ ln p(x|z_k) ]
    bce1 = 0
    bce2 = 0
    if recon_x1 is not None and x1 is not None:
        bce1 = reconstruction_function(recon_x1, x1)
    if recon_x2 is not None and x2 is not None:
        bce2 = reconstruction_function(recon_x2, x2)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)

    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batch: ldj = N E_q0[ sum_k log |det dz_k/dz_{k-1}| ]
    summed_ldj = torch.sum(ldj)

    kl = summed_logs - summed_ldj

    loss = bce1 + bce2 + beta * kl

    loss = loss / float(batch_size)
    bce1 = bce1 / float(batch_size)
    bce2 = bce2 / float(batch_size)
    kl = kl / float(batch_size)

    return loss, bce1, bce2, kl

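# Illustrative usage sketch (not part of the original module): `binary_loss_function`
# expects Bernoulli probabilities (decoder outputs already passed through a sigmoid,
# since nn.BCELoss works on probabilities, not logits) and a *variance* z_var (the
# function calls z_var.log() itself). Either modality may be dropped by passing None.
# The bimodal flow `model` below is a hypothetical placeholder.
def _example_binary_loss_step(model, x1, x2, beta=1.0):
    """Evaluate the two-modality binary loss, then the x1-only variant."""
    recon_x1, recon_x2, z_mu, z_var, z_0, z_k, ldj = model(x1, x2)

    # Both modalities present.
    loss_joint, bce1, bce2, kl = binary_loss_function(
        recon_x1, x1, recon_x2, x2, z_mu, z_var, z_0, z_k, ldj, beta=beta)

    # Second modality missing: its reconstruction term is simply 0.
    loss_x1_only, _, _, _ = binary_loss_function(
        recon_x1, x1, None, None, z_mu, z_var, z_0, z_k, ldj, beta=beta)

    return loss_joint, loss_x1_only
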
def multinomial_loss_function(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (256-way categorical) cross-entropy loss. Reconstruction and
    KL terms are summed over the batch and then divided by the batch size before returning.

    :param x_logit: shape (batch_size, num_classes * num_channels, pixel_width, pixel_height),
        real-valued logits
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height),
        pixel values rescaled to [0, 1]
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det Jacobian of the flow
    :param args: global parameter settings
    :param beta: weight on the KL term
    :return: loss, ce, kl
    """
    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes,
                           args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels from pixel values in [0, 1]
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # sums over batch and feature dimensions
    ce = F.cross_entropy(x_logit, target, reduction='sum')

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)

    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batch: ldj = N E_q0[ sum_k log |det dz_k/dz_{k-1}| ]
    summed_ldj = torch.sum(ldj)

    kl = summed_logs - summed_ldj

    loss = ce + beta * kl

    loss = loss / float(batch_size)
    ce = ce / float(batch_size)
    kl = kl / float(batch_size)

    return loss, ce, kl

def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary (Bernoulli) loss per example, without averaging or summing
    over the batch dimension.
    """
    batch_size = x.size(0)

    # if the log-det-Jacobian is not yet reduced per example, sum over its feature dims
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    # per-example Bernoulli negative log-likelihood; nn.BCELoss(reduction='none')
    # followed by a per-example sum would be roughly equivalent
    bce = -log_bernoulli(x.view(batch_size, -1), recon_x.view(batch_size, -1), dim=1)

    # ln p(z_k) (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0) (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)

    return loss

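# Illustrative usage sketch (not part of the original module): the *_array variants
# return one loss value per example, which is convenient at evaluation time, e.g. for
# collecting per-example negative ELBOs over a test set before averaging, or for
# feeding an importance-weighted estimator. `model` and `test_loader` are hypothetical
# placeholder names.
def _example_per_example_eval(model, test_loader, beta=1.0):
    """Collect per-example binary losses over a test set and report their mean."""
    per_example_losses = []
    with torch.no_grad():
        for x, _ in test_loader:
            recon_x, z_mu, z_var, z_0, z_k, ldj = model(x)
            # shape: (batch_size,) -- one negative ELBO per test example
            losses = binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=beta)
            per_example_losses.append(losses)
    per_example_losses = torch.cat(per_example_losses, dim=0)
    return per_example_losses.mean().item()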