Example #1
def binary_loss_array(x_recon, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss without averaging or summing over the batch dimension.
    """
    batch_size = x.size(0)

    # if not summed over the batch dimension
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    reconstruction_function = nn.BCEWithLogitsLoss(reduction='none')
    bce = reconstruction_function(x_recon.view(batch_size, -1), x.view(batch_size, -1))
    # sum over feature dimension
    bce = bce.view(batch_size, -1).sum(dim=1)
    
    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)
    return loss
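The function above is the per-sample flow-VAE free energy: a reconstruction term plus a single-sample KL estimate corrected by the flow's log-det-Jacobian. Below is a minimal, self-contained sketch of the same decomposition in plain PyTorch; the function name and tensor shapes are made up for illustration, and the explicit Gaussian log-densities stand in for the repo's log_normal_standard/log_normal_diag helpers.

import math
import torch
import torch.nn as nn

def flow_vae_loss_per_sample(x_recon, x, z_mu, z_var, z_0, z_k, ldj, beta=1.0):
    """Per-sample negative ELBO for a VAE with a normalizing flow on the posterior (illustrative sketch)."""
    batch_size = x.size(0)

    # - E_q0[ ln p(x|z_k) ], summed over the feature dimension
    bce = nn.BCEWithLogitsLoss(reduction='none')(
        x_recon.view(batch_size, -1), x.view(batch_size, -1)).sum(dim=1)

    # ln p(z_k) under the standard-normal prior
    log_p_zk = (-0.5 * (z_k ** 2 + math.log(2 * math.pi))).sum(dim=1)
    # ln q(z_0 | x) under the diagonal-Gaussian encoder
    log_q_z0 = (-0.5 * ((z_0 - z_mu) ** 2 / z_var + z_var.log()
                        + math.log(2 * math.pi))).sum(dim=1)

    # per-sample loss: reconstruction + beta * (KL estimate corrected by the flow's ldj)
    return bce + beta * (log_q_z0 - log_p_zk - ldj)

# illustrative call with dummy shapes
x = torch.rand(8, 1, 28, 28)
x_recon = torch.randn(8, 1, 28, 28)
z_mu, z_var = torch.zeros(8, 16), torch.ones(8, 16)
z_0 = z_k = torch.randn(8, 16)
ldj = torch.zeros(8)
print(flow_vae_loss_per_sample(x_recon, x, z_mu, z_var, z_0, z_k, ldj).shape)  # torch.Size([8])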
Example #2
def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
    """
    Computes the discritezed logistic loss without averaging or summing over the batch dimension.
    """

    num_classes = 256
    batch_size = x.size(0)

    if args.vae_layers == "linear":
        x_logit = x_logit.view(batch_size, num_classes, np.prod(args.input_size))
    else:
        x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # computes cross entropy over all dimensions separately:
    ce_loss_function = nn.CrossEntropyLoss(reduction='none')
    ce = ce_loss_function(x_logit, target)
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0.view(batch_size, -1), mean=z_mu.view(batch_size, -1),
                               log_var=safe_log(z_var).view(batch_size, -1), dim=1)

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)

    return loss
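For the multinomial case, the reconstruction term treats each pixel as a 256-way classification problem. The sketch below isolates just that step; the shapes are illustrative assumptions.

import torch
import torch.nn.functional as F

batch_size, num_classes, C, H, W = 4, 256, 3, 8, 8
x_logit = torch.randn(batch_size, num_classes, C, H, W)   # logits over 256 intensity classes per pixel
x = torch.rand(batch_size, C, H, W)                        # pixel values rescaled to [0, 1]

# make integer class labels in {0, ..., 255}
target = (x * (num_classes - 1)).long()

# cross entropy per pixel, then summed per sample (equivalent to reduction='none' followed by a feature sum)
ce = F.cross_entropy(x_logit, target, reduction='none')    # shape (batch, C, H, W)
ce = ce.view(batch_size, -1).sum(dim=1)                    # shape (batch,)
print(ce.shape)  # torch.Size([4])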
Example #3
def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss without averaging or summing over the batch dimension.
    """

    batch_size = x.size(0)

    # if not summed over the batch dimension
    if len(ldj.size()) > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    # TODO: upgrade to the newest PyTorch version on the master branch; there nn.BCELoss comes with the
    # option `reduce`, which, when set to False, does not sum over the batch dimension.
    bce = -log_bernoulli(
        x.view(batch_size, -1), recon_x.view(batch_size, -1), dim=1)
    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = bce + beta * (logs - ldj)

    return loss
Example #4
def _rho_gradient_g(self, x):
    """
    Estimate the gradient with Monte Carlo by drawing samples z_K ~ g^c and z_K ~ G^(c-1),
    and computing their densities under the full model G^c.
    """
    z_g, mu_g, var_g, ldj_g, _ = self.forward(x=x, components="c")
    g_ll = log_normal_standard(
        z_g, reduce=True, dim=-1, device=self.args.device) + ldj_g
    return g_ll.data.detach()
Пример #5
0
def boosted_neg_elbo(x_recon, x, z_mu, z_var, z_g, g_ldj, z_G, G_ldj, regularization_rate, first_component, args, beta=1.0):

    # Reconstruction term
    if args.input_type == "binary":
        reconstruction_function = nn.BCEWithLogitsLoss(reduction='sum')
        recon_loss = reconstruction_function(x_recon, x)
    elif args.input_type == "multinomial":
        num_classes = 256
        batch_size = x.size(0)

        if args.vae_layers == "linear":
            x_recon = x_recon.view(batch_size, num_classes, np.prod(args.input_size))
        else:
            x_recon = x_recon.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

        target = (x * (num_classes-1)).long()
        recon_loss = cross_entropy(x=x_recon, target=target, reduction='sum')
    else:
        raise ValueError('Invalid input type for loss calculation: %s.' % args.input_type)

    # prior: ln p(z_k)  (not averaged)
    log_p_zk = torch.sum(log_normal_standard(z_g[-1], dim=1))

    # entropy loss w.r.t. the new component terms (not averaged)
    # N E_g[ ln g(z | x) ]  (not averaged)
    log_g_base = log_normal_diag(z_g[0], mean=z_mu, log_var=safe_log(z_var), dim=1)
    log_g_z = log_g_base - g_ldj

    if first_component or (z_G is None and G_ldj is None):
        # train the first component just like a standard VAE + Normalizing Flow
        # or if we sampled from all components to alleviate decoder shock
        entropy = torch.sum(log_g_z)
        log_G_z = torch.zeros_like(entropy)
        log_ratio = torch.zeros_like(entropy).detach()
    else:
        # all other components are trained using the boosted loss
        # loss w.r.t. fixed component terms:
        log_G_base = log_normal_diag(z_G[0], mean=z_mu, log_var=safe_log(z_var), dim=1)
        log_G_z = torch.clamp(log_G_base - G_ldj, min=-1000.0)
        log_ratio = torch.sum(log_G_z.data - log_g_z.data).detach()

        # lower-bound the log-likelihoods of the FIXED components (at G_MAX_LOSS) for numerical stability
        log_G_z = torch.sum(torch.max(log_G_z, torch.ones_like(G_ldj) * G_MAX_LOSS))
        entropy = torch.sum(regularization_rate * log_g_z)

    loss = recon_loss + log_G_z + beta*(entropy - log_p_zk)

    batch_size = float(x.size(0))
    loss = loss / batch_size
    recon_loss = recon_loss / batch_size
    log_G_z = log_G_z / batch_size
    log_p_zk = -1.0 * log_p_zk / batch_size
    entropy = entropy / batch_size
    log_ratio = log_ratio / batch_size

    return loss, recon_loss, log_G_z, log_p_zk, entropy, log_ratio
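Reading off the code above: for a non-first component the batch-averaged objective combines the reconstruction term, the (floored) log-likelihood of the latents under the fixed mixture G, and an entropy-like term for the new component g:

$$
\mathcal{L} = \frac{1}{N}\Big[\mathrm{recon\_loss} + \sum_n \ln G\big(z^{(n)}\big) + \beta\Big(\lambda \sum_n \ln g\big(z^{(n)}\big) - \sum_n \ln p\big(z_K^{(n)}\big)\Big)\Big],
$$

where ln g(z) = ln g_base(z_0) - g_ldj, ln G(z) = ln G_base(z_0) - G_ldj (lower-bounded at G_MAX_LOSS), and λ is the regularization_rate. For the first component the ln G term is zero and λ = 1, so the loss reduces to the standard VAE-plus-flow objective (cf. neg_elbo below).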
Example #6
def _rho_gradient_G(self, x):
    """
    Estimate the gradient with Monte Carlo by drawing samples z_K ~ g^c and z_K ~ G^(c-1),
    and computing their densities under the full model G^c.
    """
    fixed = "-c" if self.all_trained else "1:c-1"
    z_G, mu_G, var_G, ldj_G, _ = self.forward(x=x, components=fixed)
    G_ll = log_normal_standard(
        z_G, reduce=True, dim=-1, device=self.args.device) + ldj_G
    return G_ll.data.detach()
Example #7
def neg_elbo(x_recon, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.0):
    """
    Computes the binary loss function while summing over batch dimension, not averaged!
    :param x_recon: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, ce, kl
    """
    if args.input_type == "binary":
        # - N E_q0 [ ln p(x|z_k) ]
        reconstruction_function = nn.BCEWithLogitsLoss(reduction='sum')
        recon_loss = reconstruction_function(x_recon, x)
    elif args.input_type == "multinomial":
        num_classes = 256
        batch_size = x.size(0)
        
        if args.vae_layers == "linear":
            x_recon = x_recon.view(batch_size, num_classes, np.prod(args.input_size))
        else:
            x_recon = x_recon.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2])

        # make integer class labels
        target = (x * (num_classes-1)).long()
        
        # - N E_q0 [ ln p(x|z_k) ]
        # sums over batch dimension (and feature dimension)
        recon_loss = cross_entropy(x=x_recon, target=target, reduction='sum')
    else:
        raise ValueError('Invalid input type for loss calculation: %s.' % args.input_type)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=safe_log(z_var), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
    kl = summed_logs - summed_ldj
    loss = recon_loss + beta * kl

    batch_size = x.size(0)
    loss = loss / float(batch_size)
    recon_loss = recon_loss / float(batch_size)
    kl = kl / float(batch_size)

    return loss, recon_loss, kl
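Putting the pieces of neg_elbo together, the returned loss is the batch-averaged negative ELBO with the flow correction, estimated with a single posterior sample per data point:

$$
\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N}\Big[-\ln p\big(x_n \mid z_K^{(n)}\big) + \beta\Big(\ln q\big(z_0^{(n)} \mid x_n\big) - \ln p\big(z_K^{(n)}\big) - \sum_{k=1}^{K}\ln\Big|\det\frac{\partial z_k^{(n)}}{\partial z_{k-1}^{(n)}}\Big|\Big)\Big],
$$

where the last sum is exactly the ldj term and the first term is the BCE-with-logits or cross-entropy reconstruction loss, depending on args.input_type.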
Example #8
def multinomial_loss_function(x_logit,
                              x,
                              z_mu,
                              z_var,
                              z_0,
                              z_k,
                              ldj,
                              args,
                              beta=1.):
    """
    Computes the cross entropy loss function while summing over batch dimension, not averaged!
    :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global parameter settings
    :param beta: beta for kl loss
    :return: loss, ce, kl
    """

    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0],
                           args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # sums over batch dimension (and feature dimension)
    ce = F.cross_entropy(x_logit, target, reduction='sum')

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
    kl = (summed_logs - summed_ldj)
    loss = ce + beta * kl

    loss = loss / float(batch_size)
    ce = ce / float(batch_size)
    kl = kl / float(batch_size)

    return loss, ce, kl
Example #9
    def _rho_gradients(self, x):
        full_ll = torch.zeros(x.size(0))
        fixed_ll = torch.zeros(x.size(0))
        new_ll = torch.zeros(x.size(0))
        for c in range(self.component + 1):
            z, _, _, ldj, _ = self.forward(x=x, components=c)
            if c == 0:
                full_ll = log_normal_standard(
                    z, reduce=True, dim=-1, device=self.args.device) + ldj

            else:
                new_ll = log_normal_standard(
                    z, reduce=True, dim=-1, device=self.args.device) + ldj
                # compute full model using recursive formula
                prev_ll = (torch.log(1 - self.rho[c]) + full_ll).view(
                    x.size(0), 1)
                next_ll = (torch.log(self.rho[c]) + new_ll).view(x.size(0), 1)
                full_ll = torch.logsumexp(torch.cat([prev_ll, next_ll], dim=1),
                                          dim=1)

            if c == self.component - 1:
                fixed_ll = full_ll

        return new_ll, fixed_ll, full_ll
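The recursion above (used again in compute_kl_pq_loss and evaluate below) accumulates the mixture log-likelihood in log space, with ρ_c the weight assigned to component c:

$$
\ln G_c(x) = \operatorname{logsumexp}\Big(\ln(1-\rho_c) + \ln G_{c-1}(x),\; \ln\rho_c + \ln g_c(x)\Big), \qquad \ln G_0(x) = \ln g_0(x),
$$

which avoids underflow compared with mixing the densities directly.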
Example #10
def multinomial_loss_array(x_logit,
                           x,
                           z_mu,
                           z_var,
                           z_0,
                           z_k,
                           ldj,
                           args,
                           beta=1.):
    """
    Computes the discritezed logistic loss without averaging or summing over the batch dimension.
    """

    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0],
                           args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - N E_q0 [ ln p(x|z_k) ]
    # computes cross entropy over all dimensions separately:
    ce = cross_entropy(x_logit, target, reduction='none')
    # sum over feature dimension
    ce = ce.view(batch_size, -1).sum(dim=1)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0.view(batch_size, -1),
                               mean=z_mu.view(batch_size, -1),
                               log_var=z_var.log().view(batch_size, -1),
                               dim=1)

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)

    return loss
Example #11
def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary loss function while summing over batch dimension, not averaged!
    :param recon_x: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, bce, kl
    """

    reconstruction_function = nn.BCELoss(reduction='sum')

    batch_size = x.size(0)

    # - N E_q0 [ ln p(x|z_k) ]
    bce = reconstruction_function(recon_x, x)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_normal_standard(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # sum over batches
    summed_ldj = torch.sum(ldj)

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
    kl = (summed_logs - summed_ldj)
    loss = bce + beta * kl

    loss /= float(batch_size)
    bce /= float(batch_size)
    kl /= float(batch_size)

    return loss, bce, kl
Example #12
def compute_kl_pq_loss(model, x, args):
    if args.flow == "boosted":

        if model.all_trained or model.component > 0:

            # 1. Compute likelihood/weight for each sample
            G_ll = torch.zeros(x.size(0))
            for c in range(model.component):
                z_G, _, _, ldj_G, _ = model(x=x, components=c)
                if c == 0:
                    G_ll = log_normal_standard(
                        z_G, reduce=True, dim=-1, device=args.device) + ldj_G
                else:
                    rho_simplex = model.rho[0:(c + 1)] / torch.sum(
                        model.rho[0:(c + 1)])
                    last_ll = torch.log(1 - rho_simplex[c]) + G_ll
                    next_ll = torch.log(rho_simplex[c]) + (log_normal_standard(
                        z_G, reduce=True, dim=-1, device=args.device) + ldj_G)
                    uG_ll = torch.cat([
                        last_ll.view(x.size(0), 1),
                        next_ll.view(x.size(0), 1)
                    ],
                                      dim=1)
                    G_ll = torch.logsumexp(uG_ll, dim=1)

            G_nll = -1.0 * G_ll

            # 2. Sample x with replacement, weighted by G_nll
            weights = softmax(G_nll, dim=0)
            heuristic = "unity"
            if heuristic == "decay":
                beta = 1.0 / (2.0**model.component)
            elif heuristic == "uniform":
                beta = 1.0 / (1.0 + args.num_components)
            else:
                beta = 1.0  # unity

            weights = torch.pow(weights, beta)

            if weights.max() > 0.1:
                weights = torch.max(
                    torch.min(weights, torch.tensor([0.1],
                                                    device=args.device)),
                    torch.tensor([0.01], device=args.device))
            if weights.sum() != 1.0:
                weights = weights / torch.sum(weights)

            reweighted_idx = torch.multinomial(weights,
                                               x.size(0),
                                               replacement=True)
            x_resampled = x[reweighted_idx]

            # 3. Compute g for resampled observations
            z_g, _, _, ldj_g, _ = model(x=x_resampled, components="c")
            g_nll = -1.0 * (log_normal_standard(
                z_g, reduce=True, dim=-1, device=args.device) + ldj_g)
            nll = torch.mean(g_nll)

            losses = {"nll": nll}
            losses["G_nll"] = torch.mean(G_nll)
            losses["g_nll"] = torch.mean(g_nll)

        else:
            # train first boosted component just like a non-boosted model
            z_g, _, _, ldj_g, _ = model(x=x, components="c")
            g_nll = -1.0 * (log_normal_standard(
                z_g, reduce=True, dim=-1, device=args.device) + ldj_g)
            losses = {"nll": torch.mean(g_nll)}
            losses["g_nll"] = torch.mean(g_nll)
            losses["G_nll"] = torch.zeros_like(losses['g_nll'])
    else:
        z, _, _, log_det_j, _ = model(x=x)
        log_pz = log_normal_standard(z,
                                     reduce=True,
                                     dim=-1,
                                     device=args.device)
        nll = -1.0 * (log_pz + log_det_j)

        losses = {"nll": torch.mean(nll)}
        losses['log_det_jacobian'] = torch.mean(log_det_j)
        losses['log_pz'] = torch.mean(log_pz)

    if torch.isnan(losses['nll']).any():
        raise ValueError(
            f"Nan Encountered. nll={losses['nll']}, x={x}, losses={losses}")

    return losses
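Step 2 of the boosted branch above only needs per-sample negative log-likelihoods. The weighting-and-resampling logic looks like the following sketch in isolation; the NLL values are made up, and the min/max clipping of the weights is omitted.

import torch

# hypothetical per-sample NLLs under the fixed mixture G
G_nll = torch.tensor([3.2, 7.5, 1.1, 4.8])

# turn NLLs into sampling weights: poorly-modelled samples (high NLL) get more mass
weights = torch.softmax(G_nll, dim=0)

# optional tempering, mirroring the "unity"/"decay"/"uniform" heuristics above
beta = 1.0
weights = torch.pow(weights, beta)
weights = weights / weights.sum()

# draw a same-sized batch with replacement, biased toward hard samples
reweighted_idx = torch.multinomial(weights, num_samples=G_nll.size(0), replacement=True)
print(reweighted_idx)  # e.g. tensor([1, 1, 3, 1])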
Example #13
def evaluate(model, data_loader, args, results_type=None):
    model.eval()

    if args.boosted:
        G_nll, g_nll = [], []
        for (x, _) in data_loader:
            x = x.to(args.device)

            approximate_fixed_G = False

            if approximate_fixed_G:
                # randomly sample a component
                z_G, _, _, ldj_G, _ = model(x=x, components="1:c")
                G_nll_i = -1.0 * (log_normal_standard(
                    z_G, reduce=True, dim=-1, device=args.device) + ldj_G)
                G_nll.append(G_nll_i.detach())

            else:
                G_ll = torch.zeros(x.size(0))
                for c in range(model.component + 1):
                    z_G, _, _, ldj_G, _ = model(x=x, components=c)
                    if c == 0:
                        G_ll = log_normal_standard(
                            z_G, reduce=True, dim=-1,
                            device=args.device) + ldj_G
                    else:
                        rho_simplex = model.rho[0:(c + 1)] / torch.sum(
                            model.rho[0:(c + 1)])
                        last_ll = torch.log(1 - rho_simplex[c]) + G_ll
                        next_ll = torch.log(
                            rho_simplex[c]) + (log_normal_standard(
                                z_G, reduce=True, dim=-1, device=args.device) +
                                               ldj_G)
                        uG_ll = torch.cat([
                            last_ll.view(x.size(0), 1),
                            next_ll.view(x.size(0), 1)
                        ],
                                          dim=1)
                        G_ll = torch.logsumexp(uG_ll, dim=1)

                G_nll.append(-1.0 * G_ll.detach())

            # track new component progress just for insights
            if model.component > 0 or model.all_trained:
                z_g, mu_g, var_g, ldj_g, _ = model(x=x, components="c")
                g_nll_i = -1.0 * (log_normal_standard(
                    z_g, reduce=True, dim=-1, device=args.device) + ldj_g)
                g_nll.append(g_nll_i.detach())

        G_nll = torch.cat(G_nll, dim=0)
        mean_G_nll = G_nll.mean().item()
        losses = {'nll': mean_G_nll}

        if model.component > 0 or model.all_trained:
            g_nll = torch.cat(g_nll, dim=0)
            losses['g_nll'] = g_nll.mean().item()
            losses['ratio'] = torch.mean(g_nll - G_nll).item()
        else:
            losses['g_nll'] = mean_G_nll
            losses['ratio'] = 0.0

    else:
        nll = torch.stack([
            compute_kl_pq_loss(model, x.to(args.device), args)['nll'].detach()
            for (x, _) in data_loader
        ], -1).mean().item()
        losses = {'nll': nll, 'g_nll': nll}

    if args.save_results and results_type is not None:
        results_msg = f'{results_type} set loss: {losses["nll"]:.6f}'
        logger.info(results_msg + '\n')
        with open(args.exp_log, 'a') as ff:
            print(results_msg, file=ff)

    return losses