Example #1
def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss without averaging or summing over the batch dimension.
    """

    batch_size = x.size(0)

    # if not summed over batch_dimension
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    # TODO: upgrade to the newest pytorch version on the master branch, where nn.BCELoss comes with a
    # reduce option which, when set to False, does not sum over the batch dimension.
    bce = -log_bernoulli(
        x.view(batch_size, -1), recon_x.view(batch_size, -1), dim=1)
    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)

    return loss
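
A minimal usage sketch for the per-example variant above, assuming `binary_loss_array` and its helper distributions (`log_bernoulli`, `log_normal_standard`, `log_normal_diag`) are already in scope and reduce over the given `dim`; the shapes below are illustrative only.

import torch

# Hypothetical shapes: 8 binarized 1x28x28 images and a 64-dimensional latent.
x = torch.rand(8, 1, 28, 28).round()                      # binary targets
recon_x = torch.rand(8, 1, 28, 28).clamp(1e-6, 1 - 1e-6)  # Bernoulli parameters p(x = 1)
z_mu = torch.zeros(8, 64)
z_var = torch.ones(8, 64)
z_0 = z_k = torch.randn(8, 64)
ldj = torch.zeros(8)                                      # per-sample log-det-Jacobian

per_sample_loss = binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj)
print(per_sample_loss.shape)                              # torch.Size([8]): one loss per example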
Example #2
def multinomial_loss_function(x_logit,
                              x,
                              z_mu,
                              z_var,
                              z_0,
                              z_k,
                              ldj,
                              args,
                              beta=1.):
    """
    Computes the cross entropy loss function, summed over the feature dimensions and then divided by the batch size.
    :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real-valued logits
    :param x: shape: (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled to [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global parameter settings
    :param beta: beta for kl loss
    :return: loss, ce, kl
    """

    num_classes = 256
    batch_size = x.size(0)
    input_size = [3, 28, 28]
    # x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])
    x_logit = x_logit.view(batch_size, num_classes, input_size[0],
                           input_size[1], input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # sums over batch dimension (and feature dimension)
    ce = F.cross_entropy(x_logit, target, reduction='sum')

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
    kl = (summed_logs - summed_ldj)
    loss = ce + beta * kl

    loss = loss / float(batch_size)
    ce = ce / float(batch_size)
    kl = kl / float(batch_size)

    return loss, ce, kl
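
A hedged usage sketch with dummy tensors, assuming the function and its helper distributions are in scope; `SimpleNamespace` is only a stand-in for the real `args` object, and the shapes mirror the hard-coded `input_size = [3, 28, 28]` above.

import torch
from types import SimpleNamespace

args = SimpleNamespace(input_size=[3, 28, 28])   # hypothetical stand-in for args
batch_size, num_classes, num_channels = 4, 256, 3

x = torch.rand(batch_size, num_channels, 28, 28)                       # pixels in [0, 1]
x_logit = torch.randn(batch_size, num_classes * num_channels, 28, 28)  # real-valued logits
z_mu = torch.zeros(batch_size, 64)
z_var = torch.ones(batch_size, 64)
z_0 = z_k = torch.randn(batch_size, 64)
ldj = torch.zeros(batch_size)

loss, ce, kl = multinomial_loss_function(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args)
# loss, ce and kl are scalar tensors; each has already been divided by batch_size.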
Example #3
def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary (Bernoulli) loss function, summed over the feature dimensions and then divided by the batch size.
    :param recon_x: shape: (batch_size, num_channels, pixel_width, pixel_height), Bernoulli parameters p(x=1)
    :param x: shape: (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled to [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, bce, kl
    """

    reconstruction_function = nn.BCELoss(reduction='sum')
    # print('mu: ', z_mu.mean().item(), 'logvar: ', z_var.mean().item())

    batch_size = x.size(0)

    # - N E_q0 [ ln p(x|z_k) ]
    bce = reconstruction_function(recon_x, x)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    # print('z0: ', z_0.mean().item(), 'z_k: ', z_k.mean().item())
    # log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_logvar, dim=1)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
    kl = (summed_logs - summed_ldj)

    loss = bce + beta * kl

    loss = loss / float(batch_size)
    bce = bce / float(batch_size)
    kl = kl / float(batch_size)

    return loss, bce, kl
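
A minimal smoke test for the summed binary loss, under the same assumptions as above (helper distributions in scope, illustrative shapes); with `beta` at its default of 1 the returned loss is simply the sum of the two returned terms.

import torch

x = torch.rand(8, 1, 28, 28).round()                      # binary targets
recon_x = torch.rand(8, 1, 28, 28).clamp(1e-6, 1 - 1e-6)  # keep nn.BCELoss finite
z_mu = torch.zeros(8, 64)
z_var = torch.ones(8, 64)
z_0 = z_k = torch.randn(8, 64)
ldj = torch.zeros(8)

loss, bce, kl = binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj)
print(torch.allclose(loss, bce + kl))   # True: beta defaults to 1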
Example #4
def multinomial_loss_array(x_logit,
                           x,
                           z_mu,
                           z_var,
                           z_0,
                           z_k,
                           ldj,
                           args,
                           beta=1.):
    """
    Computes the multinomial (cross entropy) loss without averaging or summing over the batch dimension.
    """

    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0],
                           args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # computes cross entropy over all dimensions separately:
    ce = F.cross_entropy(x_logit, target, reduction='none')
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0.view(batch_size, -1),
                               mean=z_mu.view(batch_size, -1),
                               log_var=z_var.log().view(batch_size, -1),
                               dim=1)

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)

    return loss
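
A hedged consistency check relating the per-example array to the averaged loss from Example #2, under the same assumptions as before (both functions and their helpers in scope, dummy tensors, `SimpleNamespace` standing in for `args`).

import torch
from types import SimpleNamespace

args = SimpleNamespace(input_size=[3, 28, 28])
batch_size = 4

x = torch.rand(batch_size, 3, 28, 28)
x_logit = torch.randn(batch_size, 256 * 3, 28, 28)
z_mu = torch.zeros(batch_size, 64)
z_var = torch.ones(batch_size, 64)
z_0 = z_k = torch.randn(batch_size, 64)
ldj = torch.zeros(batch_size)

per_sample = multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args)
avg_loss, _, _ = multinomial_loss_function(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args)
print(torch.allclose(per_sample.mean(), avg_loss))   # True up to floating point error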
Example #5
                                      args)

    else:
        raise ValueError('Invalid input type for calculate loss: %s.' %
                         args.input_type)

    return loss
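
The fragment above is only the tail of a dispatcher that selects a loss by `args.input_type`. Below is a minimal sketch of such a dispatcher, assuming it wraps the two summed loss functions shown earlier; the name `calculate_loss`, the exact signature, and the return value are inferred, not taken from the truncated source.

def calculate_loss(recon_or_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    # Hypothetical reconstruction: dispatch on the input type, as the
    # truncated fragment above suggests.
    if args.input_type == 'binary':
        loss, rec, kl = binary_loss_function(recon_or_logit, x, z_mu, z_var,
                                             z_0, z_k, ldj, beta=beta)
    elif args.input_type == 'multinomial':
        loss, rec, kl = multinomial_loss_function(recon_or_logit, x, z_mu, z_var,
                                                  z_0, z_k, ldj, args, beta=beta)
    else:
        raise ValueError('Invalid input type for calculate loss: %s.' %
                         args.input_type)

    return loss, rec, kl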


if __name__ == '__main__':
    ldj = None
    z_0 = z_k = z_mu = torch.ones((256, 64)) * 0.001
    z_log_var = torch.ones((256, 64)) * 0.0001
    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    # print('z0: ', z_0.mean().item(), 'z_k: ', z_k.mean().item())
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_log_var, dim=1)
    # print('log_q_z0', log_q_z0.mean().item(), 'log_p_zk: ', log_p_zk.mean().item())
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    difference = log_q_z0 - log_p_zk
    # print('mean difference: ', difference.mean().item())
    summed_logs = torch.sum(difference)
    # print('summed_logs: ', summed_logs.item())
    # sum over batches
    summed_ldj = torch.sum(ldj) if ldj is not None else 0.

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
    kl = (summed_logs - summed_ldj)
    # print('KLD: ', kl.mean().item())
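
For comparison, a hedged check using the closed-form KL divergence between a diagonal Gaussian and the standard normal; this is the quantity that `summed_logs` approximates (as a single-sample Monte Carlo estimate) when `z_k == z_0` and the log-det-Jacobian is zero.

import torch

z_mu = torch.ones((256, 64)) * 0.001
z_log_var = torch.ones((256, 64)) * 0.0001

# KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over batch and latent dimensions.
analytic_kl = -0.5 * torch.sum(1 + z_log_var - z_mu.pow(2) - z_log_var.exp())
print(analytic_kl.item())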