def compute_boosted_loss(mu_g, var_g, z_g, ldj_g, mu_G, var_G, z_G, ldj_G, y,
                         y_logits, dim_prod, args):
    reduction = 'mean'

    # Full objective - converted to bits per dimension
    g_nll = -1.0 * (log_normal_diag(z_g, mu_g, var_g, dim=[1, 2, 3]) + ldj_g)
    unconstrained_G_lhood = log_normal_diag(z_G, mu_G, var_G, dim=[1, 2, 3]) + ldj_G
    G_nll = -1.0 * torch.max(unconstrained_G_lhood,
                             torch.ones_like(ldj_G) * G_MAX_LOSS)

    nll = g_nll - G_nll
    bpd = nll / (math.log(2.) * dim_prod)
    losses = {"g_nll": torch.mean(g_nll)}
    losses = {"G_nll": torch.mean(G_nll)}
    losses = {"nll": torch.mean(nll)}
    losses = {"bpd": torch.mean(bpd)}

    if args.y_condition:
        if args.multi_class:
            # binary_cross_entropy_with_logits applies the sigmoid internally,
            # so the raw logits are passed directly
            loss_classes = F.binary_cross_entropy_with_logits(
                y_logits, y, reduction=reduction)
        else:
            loss_classes = F.cross_entropy(y_logits,
                                           torch.argmax(y, dim=1),
                                           reduction=reduction)

        losses["loss_classes"] = loss_classes
        losses["total_loss"] = losses["bpd"] + args.y_weight * loss_classes

    else:
        losses["total_loss"] = losses["bpd"]

    return losses
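
These snippets assume the usual imports (math, numpy as np, torch, torch.nn as nn, torch.nn.functional as F) plus project-level helpers such as log_normal_diag, log_normal_standard, safe_log, cross_entropy and the constant G_MAX_LOSS, none of which are shown on this page. As a rough guide only, here is a minimal sketch of the two Gaussian log-density helpers, assuming the third argument is a log-variance (the actual repos may use a different parameterization, e.g. a log standard deviation, or drop the additive constant):

import math
import torch

# Hypothetical stand-ins for the repos' distribution utilities.
def log_normal_diag(x, mean, log_var, dim=None):
    # Elementwise log N(x; mean, diag(exp(log_var))), summed over `dim`.
    log_norm = -0.5 * (math.log(2.0 * math.pi) + log_var
                       + (x - mean).pow(2) * torch.exp(-log_var))
    return log_norm.sum(dim=dim) if dim is not None else log_norm

def log_normal_standard(x, dim=None):
    # Elementwise log N(x; 0, I), summed over `dim`.
    log_norm = -0.5 * (math.log(2.0 * math.pi) + x.pow(2))
    return log_norm.sum(dim=dim) if dim is not None else log_norm
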
Example 2
def boosted_neg_elbo(x_recon, x, z_mu, z_var, z_g, g_ldj, z_G, G_ldj, regularization_rate, first_component, args, beta=1.0):

    # Reconstruction term
    if args.input_type == "binary":
        reconstruction_function = nn.BCEWithLogitsLoss(reduction='sum')
        recon_loss = reconstruction_function(x_recon, x)
    elif args.input_type == "multinomial":
        num_classes = 256
        batch_size = x.size(0)

        if args.vae_layers == "linear":
            x_recon = x_recon.view(batch_size, num_classes, np.prod(args.input_size))
        else:
            x_recon = x_recon.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

        target = (x * (num_classes-1)).long()
        recon_loss = cross_entropy(x=x_recon, target=target, reduction='sum')
    else:
        raise ValueError('Invalid input type for calculating loss: %s.' % args.input_type)

    # prior: ln p(z_k)  (not averaged)
    log_p_zk = torch.sum(log_normal_standard(z_g[-1], dim=1))

    # entropy loss w.r.t. new component terms (not averaged)
    # N E_g[ ln g(z | x) ]  (not averaged)
    log_g_base = log_normal_diag(z_g[0], mean=z_mu, log_var=safe_log(z_var), dim=1)
    log_g_z = log_g_base - g_ldj

    if first_component or (z_G is None and G_ldj is None):
        # train the first component just like a standard VAE + Normalizing Flow
        # or if we sampled from all components to alleviate decoder shock
        entropy = torch.sum(log_g_z)
        log_G_z = torch.zeros_like(entropy)
        log_ratio = torch.zeros_like(entropy).detach()
    else:
        # all other components are trained using the boosted loss
        # loss w.r.t. fixed component terms:
        log_G_base = log_normal_diag(z_G[0], mean=z_mu, log_var=safe_log(z_var), dim=1)
        log_G_z = torch.clamp(log_G_base - G_ldj, min=-1000.0)
        log_ratio = torch.sum(log_G_z.data - log_g_z.data).detach()

        # limit log likelihoods of FIXED components to a small number for numerical stability
        log_G_z = torch.sum(torch.max(log_G_z, torch.ones_like(G_ldj) * G_MAX_LOSS))
        entropy = torch.sum(regularization_rate * log_g_z)

    loss = recon_loss + log_G_z + beta*(entropy - log_p_zk)

    batch_size = float(x.size(0))
    loss = loss / batch_size
    recon_loss = recon_loss / batch_size
    log_G_z = log_G_z / batch_size
    log_p_zk = -1.0 * log_p_zk / batch_size
    entropy = entropy / batch_size
    log_ratio = log_ratio / batch_size

    return loss, recon_loss, log_G_z, log_p_zk, entropy, log_ratio
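
boosted_neg_elbo also relies on a safe_log helper and on the G_MAX_LOSS floor used to keep the fixed components' log-likelihood bounded. A hedged sketch of those two pieces; the epsilon and the floor value here are illustrative, not taken from the repo:

import torch

EPS = 1e-7          # illustrative epsilon
G_MAX_LOSS = -10.0  # illustrative floor for fixed-component log-likelihoods

def safe_log(t, eps=EPS):
    # Clamp before log to avoid -inf when variances underflow to zero.
    return torch.log(torch.clamp(t, min=eps))

# Floor a log-likelihood tensor at G_MAX_LOSS, as done for the fixed
# components G in the boosted objectives above.
log_G = torch.tensor([-50.0, -3.0, -0.5])
log_G_floored = torch.max(log_G, torch.full_like(log_G, G_MAX_LOSS))
print(log_G_floored)  # -> [-10.0, -3.0, -0.5]
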
Example 3
def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices):
    z_q, z_q_mean, z_q_logvar = latent_stats
    if exemplars_embedding is None and self.args.prior == 'exemplar_prior':
        exemplars_embedding = self.get_exemplar_set(z_q_mean, z_q_logvar, dataset, cache, x_indices)
    log_p_z = self.log_p_z(z=(z_q, x_indices), exemplars_embedding=exemplars_embedding)
    log_q_z = log_normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
    return -(log_p_z - log_q_z)
Example 4
def compute_loss(z, z_mu, z_var, logdet, y, y_logits, dim_prod, args):
    reduction = 'mean'

    # Full objective - converted to bits per dimension
    nll = -1.0 * (log_normal_diag(z, z_mu, z_var, dim=[1, 2, 3]) + logdet)
    bpd = nll / (math.log(2.) * dim_prod)
    losses = {"bpd": torch.mean(bpd)}

    if args.y_condition:
        if args.multi_class:
            # binary_cross_entropy_with_logits applies the sigmoid internally,
            # so the raw logits are passed directly
            loss_classes = F.binary_cross_entropy_with_logits(
                y_logits, y, reduction=reduction)
        else:
            loss_classes = F.cross_entropy(y_logits,
                                           torch.argmax(y, dim=1),
                                           reduction=reduction)

        losses["loss_classes"] = loss_classes
        losses["total_loss"] = losses["bpd"] + args.y_weight * loss_classes

    else:
        losses["total_loss"] = losses["bpd"]

    return losses
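
A minimal invocation sketch for compute_loss, assuming the imports noted earlier and log_normal_diag (see the sketch after the first example) are in scope; the shapes, values and the args namespace below are illustrative only:

import torch
from types import SimpleNamespace

# Illustrative shapes: batch of 4 latents of size 3x8x8.
z = torch.randn(4, 3, 8, 8)
z_mu, z_var = torch.zeros_like(z), torch.zeros_like(z)  # zero log-variance -> unit variance
logdet = torch.zeros(4)
args = SimpleNamespace(y_condition=False, multi_class=False, y_weight=0.01)

losses = compute_loss(z, z_mu, z_var, logdet, y=None, y_logits=None,
                      dim_prod=3 * 8 * 8, args=args)
print(losses["bpd"].item(), losses["total_loss"].item())
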
Example 5
def binary_loss_array(x_recon, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss without averaging or summing over the batch dimension.
    """
    batch_size = x.size(0)

    # if not summed over batch_dimension
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    reconstruction_function = nn.BCEWithLogitsLoss(reduction='none')
    bce = reconstruction_function(x_recon.view(batch_size, -1), x.view(batch_size, -1))
    # sum over feature dimension
    bce = bce.view(batch_size, -1).sum(dim=1)
    
    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)
    return loss
Example 6
def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the multinomial (256-way cross-entropy) loss without averaging or summing over the batch dimension.
    """

    num_classes = 256
    batch_size = x.size(0)

    if args.vae_layers == "linear":
        x_logit = x_logit.view(batch_size, num_classes, np.prod(args.input_size))
    else:
        x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # computes cross entropy over all dimensions separately:
    ce_loss_function = nn.CrossEntropyLoss(reduction='none')
    ce = ce_loss_function(x_logit, target)
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0.view(batch_size, -1), mean=z_mu.view(batch_size, -1),
                               log_var=safe_log(z_var).view(batch_size, -1), dim=1)

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)

    return loss
Example 7
def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss without averaging or summing over the batch dimension.
    """

    batch_size = x.size(0)

    # if not summed over batch_dimension
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    # TODO: upgrade to newest pytorch version on master branch, there the nn.BCELoss comes with the option
    # reduce, which when set to False, does no sum over batch dimension.
    bce = -log_bernoulli(
        x.view(batch_size, -1), recon_x.view(batch_size, -1), dim=1)
    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)

    return loss
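
The variant above uses a log_bernoulli helper instead of nn.BCEWithLogitsLoss. A minimal sketch, assuming recon_x holds Bernoulli probabilities in (0, 1):

import torch

def log_bernoulli(x, mean, dim=None, eps=1e-7):
    # Elementwise log p(x) under Bernoulli(mean), summed over `dim`.
    probs = torch.clamp(mean, min=eps, max=1.0 - eps)
    log_p = x * torch.log(probs) + (1.0 - x) * torch.log(1.0 - probs)
    return log_p.sum(dim=dim) if dim is not None else log_p
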
Example 8
def kl_loss(self, latent_stats, exemplars_embedding, dataset, cache, x_indices):
    z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = latent_stats
    if exemplars_embedding is None and self.args.prior == 'exemplar_prior':
        exemplars_embedding = self.get_exemplar_set(z2_q_mean, z2_q_logvar,
                                                    dataset, cache, x_indices)
    log_p_z1 = log_normal_diag(z1_q.view(-1, self.args.z1_size),
                               z1_p_mean.view(-1, self.args.z1_size),
                               z1_p_logvar.view(-1, self.args.z1_size), dim=1)
    log_q_z1 = log_normal_diag(z1_q.view(-1, self.args.z1_size),
                               z1_q_mean.view(-1, self.args.z1_size),
                               z1_q_logvar.view(-1, self.args.z1_size), dim=1)
    log_p_z2 = self.log_p_z(z=(z2_q, x_indices),
                            exemplars_embedding=exemplars_embedding)
    log_q_z2 = log_normal_diag(z2_q.view(-1, self.args.z2_size),
                               z2_q_mean.view(-1, self.args.z2_size),
                               z2_q_logvar.view(-1, self.args.z2_size), dim=1)
    return -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)
Example 9
def neg_elbo(x_recon, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.0):
    """
    Computes the binary loss function while summing over batch dimension, not averaged!
    :param x_recon: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, ce, kl
    """
    if args.input_type == "binary":
        # - N E_q0 [ ln p(x|z_k) ]
        reconstruction_function = nn.BCEWithLogitsLoss(reduction='sum')
        recon_loss = reconstruction_function(x_recon, x)
    elif args.input_type == "multinomial":
        num_classes = 256
        batch_size = x.size(0)
        
        if args.vae_layers == "linear":
            x_recon = x_recon.view(batch_size, num_classes, np.prod(args.input_size))
        else:
            x_recon = x_recon.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

        # make integer class labels
        target = (x * (num_classes-1)).long()
        
        # - N E_q0 [ ln p(x|z_k) ]
        # sums over batch dimension (and feature dimension)
        recon_loss = cross_entropy(x=x_recon, target=target, reduction='sum')
    else:
        raise ValueError('Invalid input type for calculating loss: %s.' % args.input_type)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
    kl = summed_logs - summed_ldj
    loss = recon_loss + beta * kl

    batch_size = x.size(0)
    loss = loss / float(batch_size)
    recon_loss = recon_loss / float(batch_size)
    kl = kl / float(batch_size)

    return loss, recon_loss, kl
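
A minimal usage sketch for neg_elbo in the binary case, assuming log_normal_standard, log_normal_diag and safe_log (or the sketches above) are importable; shapes and the args namespace are illustrative only:

import torch
from types import SimpleNamespace

batch, z_dim = 8, 16
x = torch.randint(0, 2, (batch, 1, 28, 28)).float()   # binarized pixels
x_recon = torch.randn_like(x)                          # logits for BCEWithLogitsLoss
z_mu, z_var = torch.zeros(batch, z_dim), torch.ones(batch, z_dim)
z_0 = z_k = torch.randn(batch, z_dim)                  # identity flow for illustration
ldj = torch.zeros(batch)
args = SimpleNamespace(input_type="binary")

loss, recon, kl = neg_elbo(x_recon, x, z_mu, z_var, z_0, z_k, ldj, args)
print(loss.item(), recon.item(), kl.item())
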
Example 10
def multinomial_loss_function(x_logit,
                              x,
                              z_mu,
                              z_var,
                              z_0,
                              z_k,
                              ldj,
                              args,
                              beta=1.):
    """
    Computes the cross entropy loss function while summing over batch dimension, not averaged!
    :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global parameter settings
    :param beta: beta for kl loss
    :return: loss, ce, kl
    """

    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0],
                           args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # sums over batch dimension (and feature dimension)
    ce = F.cross_entropy(x_logit, target, reduction='sum')

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
    kl = (summed_logs - summed_ldj)
    loss = ce + beta * kl

    loss = loss / float(batch_size)
    ce = ce / float(batch_size)
    kl = kl / float(batch_size)

    return loss, ce, kl
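
Several of the multinomial losses above and below build integer class labels from pixels rescaled to [0, 1] via (x * (num_classes - 1)).long(). A tiny self-contained check of that mapping:

import torch

x = torch.tensor([0.0, 0.5, 1.0])   # pixels rescaled to [0, 1]
target = (x * 255).long()           # -> classes 0..255
print(target)                       # tensor([  0, 127, 255])
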
Example 11
def forward(self, sample, logdet=0.0, reverse=False, temperature=None):
    if reverse:
        z1 = sample
        z_mu, z_var = self.split2d_prior(z1)
        z2 = torch.normal(z_mu, torch.exp(z_var) * temperature)
        z = torch.cat((z1, z2), dim=1)
        return z, logdet
    else:
        z1, z2 = split_feature(sample, "split")
        z_mu, z_var = self.split2d_prior(z1)
        logdet = log_normal_diag(z2, z_mu, z_var, dim=[1, 2, 3]) + logdet
        return z1, logdet
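
The Split2d forward above depends on a split_feature helper. A minimal sketch of a channel-wise split, assuming "split" halves the channel dimension as in common Glow implementations:

import torch

def split_feature(tensor, method="split"):
    # Split along the channel dimension into two halves ("split")
    # or into even/odd channels ("cross").
    c = tensor.size(1)
    if method == "split":
        return tensor[:, : c // 2], tensor[:, c // 2:]
    elif method == "cross":
        return tensor[:, 0::2], tensor[:, 1::2]
    raise ValueError("unknown split method: %s" % method)
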
Example 12
def multinomial_loss_array(x_logit,
                           x,
                           z_mu,
                           z_var,
                           z_0,
                           z_k,
                           ldj,
                           args,
                           beta=1.):
    """
    Computes the multinomial (256-way cross-entropy) loss without averaging or summing over the batch dimension.
    """

    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0],
                           args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # computes cross entropy over all dimensions separately:
    ce = cross_entropy(x_logit, target, reduction='none')
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0.view(batch_size, -1),
                               mean=z_mu.view(batch_size, -1),
                               log_var=z_var.log().view(batch_size, -1),
                               dim=1)

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)

    return loss
Example 13
def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss function while summing over batch dimension, not averaged!
    :param recon_x: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, ce, kl
    """

    # sum (rather than average) the reconstruction error over all elements
    reconstruction_function = nn.BCELoss(reduction='sum')

    batch_size = x.size(0)

    # - N E_q0 [ ln p(x|z_k) ]
    bce = reconstruction_function(recon_x, x)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
    kl = (summed_logs - summed_ldj)
    loss = bce + beta * kl

    loss /= float(batch_size)
    bce /= float(batch_size)
    kl /= float(batch_size)

    return loss, bce, kl
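
binary_loss_function needs the reconstruction error summed, not averaged, over all elements; nn.BCELoss(reduction='sum') provides exactly that. A quick self-contained check:

import torch
import torch.nn as nn

p = torch.tensor([0.9, 0.2, 0.6])   # predicted Bernoulli probabilities
t = torch.tensor([1.0, 0.0, 1.0])   # binary targets
summed = nn.BCELoss(reduction='sum')(p, t)
manual = -(t * p.log() + (1 - t) * (1 - p).log()).sum()
print(torch.allclose(summed, manual))  # True
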