def regression_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the regression (MSE) loss plus the beta-weighted KL term of the flow.
    All three returned values are divided by the batch size.

    :param recon_x: model output, compared against x with nn.MSELoss
    :param x: regression target; x.size(0) is taken as the batch size
    :param z_mu: mean of z_0
    :param z_var: variance of z_0 (its log is passed to log_normal_diag as log_var)
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, reg_loss, kl
    """
    n_samples = float(x.size(0))

    # Reconstruction term.
    # NOTE(review): nn.MSELoss() is built with its default arguments and the result is
    # divided by the batch size once more below. If the default already averages
    # (as in stock PyTorch), this term is normalised twice -- confirm this is intended.
    mse_criterion = nn.MSELoss()
    recon_term = mse_criterion(recon_x, x)

    # ln p(z_k) and ln q(z_0), one value per sample (helpers are defined elsewhere)
    log_p_zk = log_normal_standard(z_k, dim=1)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)

    # sum over the batch of [ ln q(z_0) - ln p(z_k) ] minus the summed log det jacobian
    batch_log_ratio = (log_q_z0 - log_p_zk).sum()
    batch_ldj = torch.sum(ldj)
    kl_term = batch_log_ratio - batch_ldj

    total = recon_term + beta * kl_term

    # Scale every returned quantity by 1 / batch_size, in place.
    for term in (total, recon_term, kl_term):
        term.div_(n_samples)

    return total, recon_term, kl_term
def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (256-class cross entropy) loss without averaging or summing
    over the batch dimension.

    :param x_logit: real valued logits; viewed as
        (batch_size, 256, args.input_size[0], args.input_size[1], args.input_size[2])
    :param x: pixel values rescaled between [0, 1]; x.size(0) is taken as the batch size
    :param z_mu: mean of z_0
    :param z_var: variance of z_0 (its log is passed to log_normal_diag as log_var)
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian; if it has more than one dimension it is summed down to
        one value per sample
    :param args: global parameter settings (only args.input_size is read here)
    :param beta: beta for kl loss
    :return: loss with one entry per sample
    """
    num_classes = 256
    batch_size = x.size(0)

    # If ldj still carries extra dimensions, collapse it to one value per sample.
    # Without this, `logs - ldj` below would broadcast a (batch,) tensor against e.g. a
    # (batch, 1) tensor and silently produce a loss of the wrong shape.
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    x_logit = x_logit.view(batch_size, num_classes,
                           args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels
    # NOTE(review): .long() truncates toward zero; confirm x never lands just below an
    # integer after scaling.
    target = (x * (num_classes - 1)).long()

    # - ln p(x|z_k)
    # computes cross entropy over all dimensions separately:
    ce = cross_entropy(x_logit, target, size_average=False, reduce=False)
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(
        z_0.view(batch_size, -1),
        mean=z_mu.view(batch_size, -1),
        log_var=z_var.log().view(batch_size, -1),
        dim=1
    )

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)

    return loss
def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss per sample, i.e. without averaging or summing over the
    batch dimension.

    :param recon_x: bernoulli parameters; flattened to (batch_size, -1) before use
    :param x: targets; flattened to (batch_size, -1), x.size(0) is the batch size
    :param z_mu: mean of z_0
    :param z_var: variance of z_0 (its log is passed to log_normal_diag as log_var)
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian; summed down to one value per sample if it has more
        than one dimension
    :param beta: beta for kl loss
    :return: loss with one entry per sample
    """
    n_rows = x.size(0)

    # ldj may arrive not yet summed per sample: collapse everything but dim 0
    if ldj.dim() > 1:
        ldj_flat = ldj.view(ldj.size(0), -1)
        ldj = ldj_flat.sum(-1)

    # TODO: newer PyTorch versions offer an unreduced nn.BCELoss (no sum over the batch
    # dimension); until the project upgrades, the log_bernoulli helper is used instead.
    flat_target = x.view(n_rows, -1)
    flat_probs = recon_x.view(n_rows, -1)
    neg_log_lik = -log_bernoulli(flat_target, flat_probs, dim=1)

    # ln p(z_k)  (not averaged)
    log_prior_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_post_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)

    # bce + beta * ( [ln q(z_0) - ln p(z_k)] - ldj )
    return neg_log_lik + beta * (log_post_z0 - log_prior_zk - ldj)
def multinomial_loss_function(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the cross entropy loss function: every term is summed over the batch and
    the returned values are then divided by the batch size.

    :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
    :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global parameter settings
    :param beta: beta for kl loss
    :return: loss, ce, kl
    """
    n_classes = 256
    n_samples = x.size(0)
    divisor = float(n_samples)

    # logits reshaped so that the class axis is dim 1
    dims = args.input_size
    logits = x_logit.view(n_samples, n_classes, dims[0], dims[1], dims[2])

    # integer class labels from the rescaled pixel values
    labels = (x * (n_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # summed over batch dimension (and feature dimension)
    ce = cross_entropy(logits, labels, size_average=False)

    # ln p(z_k)  (not averaged)
    log_prior_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_post_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)

    # N E_q0[ ln q(z_0) - ln p(z_k) ], minus the log det jacobian summed over the batch
    kl = (log_post_z0 - log_prior_zk).sum() - torch.sum(ldj)

    loss = ce + beta * kl

    # in-place scaling of all returned quantities by 1 / batch_size
    loss.div_(divisor)
    ce.div_(divisor)
    kl.div_(divisor)

    return loss, ce, kl
def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss function: every term is summed over the batch and the
    returned values are then divided by the batch size.

    :param recon_x: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, bce, kl
    """
    # - N E_q0 [ ln p(x|z_k) ]
    bce = nn.BCELoss(size_average=False)(recon_x, x)

    n_samples = float(x.size(0))

    # ln p(z_k)  (not averaged)
    log_prior_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_post_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)

    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    log_ratio_sum = (log_post_z0 - log_prior_zk).sum()
    # ldj = N E_q_z0[ sum_k log |det dz_k/dz_k-1| ], summed over the batch
    ldj_sum = torch.sum(ldj)

    kl = log_ratio_sum - ldj_sum
    loss = bce + beta * kl

    # in-place scaling of all returned quantities by 1 / batch_size
    for quantity in (loss, bce, kl):
        quantity.div_(n_samples)

    return loss, bce, kl