예제 #1
0
파일: vae.py 프로젝트: baohq1595/vae-dec
    def _kld(self, z, q_param, p_param=None):
        """
        Single-sample estimate of the KL-divergence at the point ``z``.

        KL(q||p) = -∫ q(z) log [ p(z) / q(z) ]
                 = -E_q[log p(z) - log q(z)]

        Evaluating ``log q(z) - log p(z)`` at one sample ``z`` drawn from q
        gives a Monte Carlo estimate of the expectation above.

        :param z: sample from the q-distribution
        :param q_param: (mu, log_var) of the q-distribution
        :param p_param: (mu, log_var) of the p-distribution; when None,
            ``log_standard_gaussian(z)`` is used for log p(z)
        :return: log q(z) - log p(z), the per-sample KL(q||p) estimate
        """
        q_mu, q_log_var = q_param
        log_q = log_gaussian(z, q_mu, q_log_var)

        if p_param is not None:
            p_mu, p_log_var = p_param
            log_p = log_gaussian(z, p_mu, p_log_var)
        else:
            # No explicit prior parameters supplied: fall back to the standard prior.
            log_p = log_standard_gaussian(z)

        return log_q - log_p
예제 #2
0
    def _kld(self, z, q_param,  i, h_last, p_param=None, sylvester_params=None, auxiliary=False):
        """
        Computes a single-sample estimate of the KL-divergence at some
        element z, pushing z through the model's flow first when
        ``self.n_flows > 0``.

        KL(q||p) = -∫ q(z) log [ p(z) / q(z) ]
                 = -E[log p(z) - log q(z)]

        log q is evaluated at the incoming z (the flow input); log p is
        evaluated at the flow output when a flow branch runs, otherwise at z.

        :param z: sample from the q-distribution
        :param q_param: (mu, log_var) of the q-distribution; for the
            "o-sylvester" / "h-sylvester" / "t-sylvester" flow types it is
            unpacked as (mu, log_var, r1, r2, q_ortho, b)
        :param i: index forwarded to ``self.flow`` / ``self.flow_a``
        :param h_last: forwarded to the flow only for "hf" / "ccLinIAF"
        :param p_param: (mu, log_var) of the p-distribution; when None,
            ``log_standard_gaussian`` is used instead
        :param sylvester_params: not used by this method
        :param auxiliary: when True, ``self.flow_a`` is called instead of
            ``self.flow`` (and the flow's last positional argument is True)
        :return: qz - pz, i.e. log q(z) - log p(f(z)); for "nf" the flow's
            log-dets are subtracted from log q first
        """
        if self.flow_type == "nf" and self.n_flows > 0:
            # "nf" flow returns (transformed z, log-dets); the log-dets are
            # subtracted from log q(z_0) below.
            (mu, log_var) = q_param
            if not auxiliary:
                f_z, log_det_z = self.flow(z, i, False)
            else:
                f_z, log_det_z = self.flow_a(z, i, True)
            # Python's builtin sum() iterates log_det_z (list entries or first dim).
            qz = log_gaussian(z, mu, log_var) - sum(log_det_z)
            z = f_z
        elif self.flow_type in ["hf", "ccLinIAF"] and self.n_flows > 0:
            # NOTE(review): no log-det term is subtracted here. That is only
            # correct if these flows are volume-preserving (|det J| = 1) --
            # confirm for both "hf" and "ccLinIAF".
            (mu, log_var) = q_param
            if not auxiliary:
                f_z = self.flow(z, i, h_last, False)
            else:
                f_z = self.flow_a(z, i, h_last, True)
            qz = log_gaussian(z, mu, log_var)
            z = f_z
        elif self.flow_type in ["o-sylvester", "h-sylvester", "t-sylvester"] and self.n_flows > 0:
            # q_param carries the Sylvester parameters alongside (mu, log_var);
            # q_ortho is passed to the flow exactly as received here.
            mu, log_var, r1, r2, q_ortho, b = q_param
            if not auxiliary:
                f_z = self.flow(z, r1, r2, q_ortho, b, i, False)
            else:
                f_z = self.flow_a(z, r1, r2, q_ortho, b, i, True)
            # NOTE(review): unlike the "nf" branch, no log-det-Jacobian is
            # subtracted from log q here, and self.flow is only seen returning
            # f_z. Sylvester flows are not volume-preserving in general --
            # confirm the correction is applied elsewhere.
            qz = log_gaussian(z, mu, log_var)
            z = f_z
        else:
            # No flow: plain Gaussian posterior evaluated at z.
            (mu, log_var) = q_param
            qz = log_gaussian(z, mu, log_var)

        if p_param is None:
            pz = log_standard_gaussian(z)
        else:
            # mu / log_var are rebound to the prior's parameters here.
            (mu, log_var) = p_param
            pz = log_gaussian(z, mu, log_var)

        kl = qz - pz
        return kl
예제 #3
0
def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Per-example binary (Bernoulli) VAE loss with a flow correction.

    Nothing is averaged or summed across the batch: the result holds one loss
    value per example.

    :param recon_x: reconstructed Bernoulli parameters; first dim is the batch
    :param x: target data; first dim is the batch
    :param z_mu: mean of q(z_0)
    :param z_var: variance of q(z_0); its log is passed to ``log_gaussian``
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log-det-Jacobian of the flows; if it has more than one
        dimension, everything but the first dimension is summed out
    :param beta: weight of the KL term
    :return: bce + beta * (ln q(z_0) - ln p(z_k) - ldj), one value per example
    """
    n = x.size(0)

    # Collapse a non-flat log-det-Jacobian down to one value per example.
    if ldj.dim() > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(dim=-1)

    # -ln p(x | z_k), summed over the feature dimension.
    # TODO: newer PyTorch BCE losses accept reduction='none', which could replace this.
    flat_x = x.view(n, -1)
    flat_recon = recon_x.view(n, -1)
    bce = -log_bernoulli(flat_x, flat_recon, dim=1)

    # ln p(z_k) and ln q(z_0), per example (not averaged).
    log_prior_zk = log_standard_gaussian(z_k, dim=1)
    log_posterior_z0 = log_gaussian(z_0, mean=z_mu, log_var=z_var.log(), dim=1)

    kl_term = log_posterior_z0 - log_prior_zk - ldj
    return bce + beta * kl_term
예제 #4
0
    def calculate_losses(self,
                         data,
                         lambda1=0.,
                         lambda2=0.,
                         beta=1.,
                         likelihood=F.mse_loss):
        """
        Compute the training loss and its components for one batch.

        Side effects:
        - sets ``self.images_path`` and creates that folder;
        - on the Sylvester (non-ladder) path, sets ``self.log_det_j`` and
          ``self.kl_divergence``.

        :param data: input batch; ``torch.tanh`` is applied before it is encoded
            and before it is used as the reconstruction target
        :param lambda1: weight of the L1 penalty on the reconstruction-layer parameters
        :param lambda2: weight of the L2 penalty on the reconstruction-layer parameters
        :param beta: weight of the KL term
        :param likelihood: reconstruction loss callable, called as
            ``likelihood(reconstruction, target, reduce=False)``; its output is
            summed over the last dimension
        :return: (loss, mean reconstruction term, mean KL term, reconstruction, z_q)
        """
        ladder = "ladder" if self.ladder else "not_ladder"
        self.images_path = self.results_path + "/images/examples/generative/" + ladder + "/" + self.flavour + "/"
        create_missing_folders(self.images_path)

        # NOTE(review): some forward() implementations in this project also
        # apply torch.tanh to their input; on the self(data) path the encoder
        # would then see tanh(tanh(data)) while the target is tanh(data). Confirm.
        data = torch.tanh(data)
        if self.flow_type in ["o-sylvester", "t-sylvester", "h-sylvester"] and not self.ladder:
            # z_q keeps its placeholder key 1 (None); the first and last latent
            # samples are stored under keys 0 and -1.
            z_q = {0: None, 1: None}
            (reconstruction, mu, log_var, self.log_det_j, z_q[0],
             z_q[-1]) = self.run_sylvester(data, auxiliary=self.auxiliary)
            log_p_zk = log_standard_gaussian(z_q[-1])
            # ln q(z_0) minus the flows' log-det-Jacobian (not averaged)
            log_q_z0 = log_gaussian(z_q[0], mu, log_var=log_var) - self.log_det_j
            # N E_q0[ ln q(z_0) - ln p(z_k) ]
            self.kl_divergence = log_q_z0 - log_p_zk
        else:
            # forward() must have set self.kl_divergence; it is read below.
            reconstruction, z_q = self(data)

        kl = beta * self.kl_divergence

        # Per-example reconstruction term: element-wise loss summed over the last dim.
        recon = torch.sum(likelihood(reconstruction, data.float(), reduce=False), dim=-1)
        device = recon.device

        if self.ladder:
            reconstruction_module = self.reconstruction
        else:
            reconstruction_module = self.decoder.reconstruction
        params = torch.cat([p.view(-1) for p in reconstruction_module.parameters()])

        # Follow the reconstruction loss's device instead of hard-coding
        # .cuda(): unchanged when everything lives on the GPU, and it also
        # works on CPU-only machines.
        l1_regularization = lambda1 * torch.norm(params, 1).to(device)
        l2_regularization = lambda2 * torch.norm(params, 2).to(device)
        # Explicit check instead of assert + bare except. It still only
        # reports and never raises. It fires for negative or NaN penalties
        # and is not stripped under ``python -O``.
        if not (l1_regularization >= 0. and l2_regularization >= 0.):
            print(l1_regularization, l2_regularization)

        loss = torch.mean(recon + kl.to(device) + l1_regularization + l2_regularization)

        return loss, torch.mean(recon), torch.mean(kl), reconstruction, z_q
예제 #5
0
def multinomial_loss_function(x_logit,
                              x,
                              z_mu,
                              z_var,
                              z_0,
                              z_k,
                              ldj,
                              args,
                              beta=1.):
    """
    Computes the 256-way categorical (cross-entropy) VAE loss with a flow
    correction.

    Every term is first summed over the batch (and feature) dimensions and then
    divided by the batch size, so the returned values are per-example means.

    :param x_logit: real-valued logits, viewable as
        (batch_size, 256, args.input_size[0], args.input_size[1], args.input_size[2])
    :param x: shape (batch_size, num_channels, pixel_width, pixel_height), pixel
        values rescaled to [0, 1]
    :param z_mu: mean of z_0
    :param z_var: variance of z_0; its log is passed to ``log_gaussian``
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param args: global settings; ``args.input_size`` supplies the trailing dims
    :param beta: weight of the kl term
    :return: loss, ce, kl (each divided by the batch size)
    """
    n_classes = 256
    n = x.size(0)

    logits = x_logit.view(n, n_classes,
                          args.input_size[0], args.input_size[1], args.input_size[2])
    # Pixel intensities in [0, 1] -> integer class labels.
    labels = (x * (n_classes - 1)).long()

    # -sum over batch and features of ln p(x|z_k)
    ce = cross_entropy(logits, labels, size_average=False)

    log_p_zk = log_standard_gaussian(z_k, dim=1)
    log_q_z0 = log_gaussian(z_0, mean=z_mu, log_var=z_var.log(), dim=1)

    # sum_n [ln q(z_0) - ln p(z_k)] - sum_n sum_k log|det dz_k/dz_{k-1}|
    kl = torch.sum(log_q_z0 - log_p_zk) - torch.sum(ldj)
    loss = ce + beta * kl

    # Scale in place; ce and kl are returned as the same tensor objects used above.
    norm = float(n)
    loss /= norm
    ce /= norm
    kl /= norm
    return loss, ce, kl
예제 #6
0
def multinomial_loss_array(x_logit,
                           x,
                           z_mu,
                           z_var,
                           z_0,
                           z_k,
                           ldj,
                           args,
                           beta=1.):
    """
    Computes the 256-way categorical (cross-entropy) loss without averaging or
    summing over the batch dimension.

    :param x_logit: real-valued logits, viewable as
        (batch_size, 256, args.input_size[0], args.input_size[1], args.input_size[2])
    :param x: data with pixel values rescaled to [0, 1]
    :param z_mu: mean of q(z_0)
    :param z_var: variance of q(z_0); its log is passed to ``log_gaussian``
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log-det-Jacobian of the flows, one value per example; a tensor
        with extra dimensions is summed down to one value per example
    :param args: settings object providing ``input_size``
    :param beta: weight of the KL term
    :return: per-example loss, ce + beta * (ln q(z_0) - ln p(z_k) - ldj)
    """
    num_classes = 256
    batch_size = x.size(0)

    x_logit = x_logit.view(batch_size, num_classes, args.input_size[0],
                           args.input_size[1], args.input_size[2])

    # make integer class labels
    target = (x * (num_classes - 1)).long()

    # - ln p(x|z_k): element-wise cross entropy, then summed per example
    ce = cross_entropy(x_logit, target, size_average=False, reduce=False)
    ce = ce.view(batch_size, -1).sum(dim=1)

    # Collapse a multi-dimensional log-det-Jacobian to one value per example
    # (consistent with binary_loss_array). Otherwise the subtraction below
    # would broadcast to the wrong shape, or fail, instead of aligning per
    # example. Plain floats (e.g. 0. when there are no flows) pass through.
    if torch.is_tensor(ldj) and ldj.dim() > 1:
        ldj = ldj.view(ldj.size(0), -1).sum(-1)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_standard_gaussian(z_k.view(batch_size, -1), dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_gaussian(z_0.view(batch_size, -1),
                            mean=z_mu.view(batch_size, -1),
                            log_var=z_var.log().view(batch_size, -1),
                            dim=1)

    # ln q(z_0) - ln p(z_k)
    logs = log_q_z0 - log_p_zk

    loss = ce + beta * (logs - ldj)

    return loss
예제 #7
0
def mse_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the squared-error VAE loss with a flow correction, summed over the
    batch and then divided by the batch size.

    The target ``x`` is passed through ``torch.tanh`` before comparison.

    :param recon_x: reconstruction, same shape as ``x``
    :param x: target batch (first dim is the batch)
    :param z_mu: mean of z_0
    :param z_var: passed to ``log_gaussian`` as ``log_var``, i.e. treated as a
        log-variance here (binary_loss_function instead takes ``z_var.log()``).
        NOTE(review): confirm which quantity callers pass.
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, mse, kl (each summed over the batch and divided by the batch size)
    """
    x = torch.tanh(x)
    batch_size = x.size(0)

    # Element-wise squared error, summed over the last dimension.
    # Do not wrap this in Variable(...): in PyTorch >= 0.4 that returns a
    # detached tensor, which would cut the reconstruction term out of the
    # gradient.
    reconstruction_function = nn.MSELoss(reduction='none')
    mse = torch.sum(reconstruction_function(recon_x, x), dim=-1)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_standard_gaussian(z_k)
    # ln q(z_0) minus the flows' log-det-Jacobian  (not averaged)
    log_q_z0 = log_gaussian(z_0, z_mu, log_var=z_var) - ldj
    # NOTE(review): abs() of the per-sample log-ratio is not the standard KL
    # estimator; kept unchanged -- confirm intent.
    kl = abs(log_q_z0 - log_p_zk)

    loss = torch.sum(mse + beta * kl) / float(batch_size)
    mse = torch.sum(mse) / float(batch_size)
    kl = torch.sum(kl) / float(batch_size)

    return loss, mse, kl
예제 #8
0
def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
    """
    Computes the binary (Bernoulli) VAE loss with a flow correction. Every term
    is summed over the batch (and features) and then divided by the batch size.

    :param recon_x: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
    :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
    :param z_mu: mean of z_0
    :param z_var: variance of z_0; its log is passed to ``log_gaussian``
    :param z_0: first stochastic latent variable
    :param z_k: last stochastic latent variable
    :param ldj: log det jacobian
    :param beta: beta for kl loss
    :return: loss, bce, kl (each divided by the batch size)
    """
    # Request a sum explicitly. Assigning ``size_average = False`` on the
    # module after construction has no effect in PyTorch >= 0.4.1 (the
    # reduction mode is fixed in __init__). That silently turned this into a
    # mean, which was then divided by the batch size a second time.
    reconstruction_function = nn.BCELoss(reduction='sum')

    batch_size = x.size(0)

    # - N E_q0 [ ln p(x|z_k) ]
    bce = reconstruction_function(recon_x, x)

    # ln p(z_k)  (not averaged)
    log_p_zk = log_standard_gaussian(z_k, dim=1)
    # ln q(z_0)  (not averaged)
    log_q_z0 = log_gaussian(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
    # N E_q0[ ln q(z_0) - ln p(z_k) ]
    summed_logs = torch.sum(log_q_z0 - log_p_zk)

    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ], summed over batches
    summed_ldj = torch.sum(ldj)

    kl = summed_logs - summed_ldj
    loss = bce + beta * kl

    n = float(batch_size)
    return loss / n, bce / n, kl / n
예제 #9
0
    def run_sylvester(self,
                      x,
                      y=None,
                      a=None,
                      k=0,
                      auxiliary=True):
        r"""
        Forward pass with Sylvester flows for the transformation z_0 -> z_1 -> ... -> z_k.
        Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ].

        Side effects: sets ``self.log_det_j`` and ``self.kl_divergence``
        (ln q(z_0) - log_det_j - ln p(z_K), not averaged).

        :param x: input batch; a 2-D input is reshaped to
            (-1, input_shape[0], input_shape[1], input_shape[2])
        :param y: tensor forwarded to ``self.encode`` and ``self.sample``;
            defaults to an empty float tensor on ``x``'s device
        :param a: tensor forwarded to ``self.encode``; same default as ``y``
        :param k: layer index, forwarded to ``self.encode`` and used in the flow
            attribute names ``flow_<k>_<i>_<auxiliary>``
        :param auxiliary: selects the auxiliary flows; when True no
            reconstruction is produced (``x_mean`` is None)
        :return: (x_mean, z_mu, z_var, log_det_j, z_0, z_K)
        :raises NotImplementedError: for an unsupported ``self.flow_type``
            when ``self.n_flows > 0``
        """
        # Defaults are created per call on x's device. The former def-time
        # ``torch.Tensor([]).cuda()`` required a GPU just to import the module
        # and shared one tensor across all calls.
        if y is None:
            y = torch.empty(0, device=x.device)
        if a is None:
            a = torch.empty(0, device=x.device)

        if len(x.shape) == 2:
            x = x.view(-1, self.input_shape[0], self.input_shape[1],
                       self.input_shape[2])
        self.log_det_j = 0.
        (z_mu, z_var, r1, r2, q, b), x, z_q = self.encode(x,
                                                          y,
                                                          a,
                                                          i=k,
                                                          auxiliary=auxiliary)
        # Orthogonalize all q matrices (triangular flows use none).
        if self.flow_type == "o-sylvester":
            q_ortho = self.batch_construct_orthogonal(q, auxiliary)
        elif self.flow_type == "h-sylvester":
            q_ortho = self.batch_construct_householder_orthogonal(q, auxiliary)
        else:
            q_ortho = None

        # Sample z_0
        z = [self.reparameterize(z_mu, z_var)]
        # Normalizing flows
        for i in range(self.n_flows):
            flow_k = getattr(
                self, 'flow_' + str(k) + "_" + str(i) + "_" + str(auxiliary))
            if self.flow_type == "o-sylvester":
                # z is a list, so z[i] is the only valid indexing; no fallback needed.
                z_k, log_det_jacobian = flow_k(zk=z[i],
                                               r1=r1[:, :, :, i],
                                               r2=r2[:, :, :, i],
                                               q_ortho=q_ortho[i, :, :, :],
                                               b=b[:, :, :, i])
            elif self.flow_type == "h-sylvester":
                q_k = q_ortho[i]
                z_k, log_det_jacobian = flow_k(z[i], r1[:, :, :, i],
                                               r2[:, :, :, i], q_k,
                                               b[:, :, :, i])
            elif self.flow_type == "t-sylvester":
                # Alternate with reordering z for triangular flow: every
                # odd-numbered flow in this chain (index i, not the layer
                # index k) sees the permuted z.
                permute_z = self.flip_idx if i % 2 == 1 else None
                z_k, log_det_jacobian = flow_k(zk=z[i],
                                               r1=r1[:, :, :, i],
                                               r2=r2[:, :, :, i],
                                               b=b[:, :, :, i],
                                               permute_z=permute_z,
                                               sum_ldj=True,
                                               auxiliary=auxiliary)
            else:
                raise NotImplementedError(
                    "Non implemented flow type: {}".format(self.flow_type))
            z.append(z_k)
            self.log_det_j += log_det_jacobian

        log_p_zk = log_standard_gaussian(z[-1])
        # ln q(z_0) minus the flows' log-det-Jacobian  (not averaged)
        log_q_z0 = log_gaussian(z[0], z_mu, log_var=z_var) - self.log_det_j
        # N E_q0[ ln q(z_0) - ln p(z_k) ]
        self.kl_divergence = log_q_z0 - log_p_zk
        if auxiliary:
            x_mean = None
        else:
            x_mean = self.sample(z[-1], y)

        return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1]
예제 #10
0
    def forward(self,
                x,
                y=None,
                a=None,
                k=0,
                auxiliary=False):
        r"""
        Forward pass with Sylvester flows for the transformation z_0 -> z_1 -> ... -> z_k.
        Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ].

        Side effects: sets ``self.log_det_j``, ``self.sylvester_params`` and
        ``self.model.kl_divergence``.

        :param x: input batch; concatenated with ``y`` along dim 1 before encoding
        :param y: conditioning tensor; defaults to an empty float tensor on
            ``x``'s device
        :param a: not used by this method (default kept for compatibility)
        :param k: layer index used in the flow attribute names
            ``flow_<k>_<i>_<auxiliary>``
        :param auxiliary: forwarded to ``self.encode`` and selects the flows
        :return: (x_mean, z_mu, z_var, log_det_j, z_0, z_K)
        :raises NotImplementedError: for an unsupported ``self.flow_type``
            when ``self.n_flows > 0``
        """
        # Defaults are created per call on x's device instead of once at
        # definition time with .cuda().
        if y is None:
            y = torch.empty(0, device=x.device)
        if a is None:
            a = torch.empty(0, device=x.device)

        self.log_det_j = 0.

        (z_mu, z_var, r1, r2, q, b), x, z_q = self.encode(torch.cat([x, y], 1),
                                                          auxiliary=auxiliary)
        self.sylvester_params = (r1, r2, q, b)
        if self.flow_type == "o-sylvester":
            q_ortho = self.batch_construct_orthogonal(q)
        elif self.flow_type == "h-sylvester":
            q_ortho = self.batch_construct_householder_orthogonal(q)
        else:
            q_ortho = None

        # Sample z_0
        z = [self.reparameterize(z_mu, z_var)]
        # Normalizing flows
        for i in range(self.n_flows):
            flow_k = getattr(
                self, 'flow_' + str(k) + "_" + str(i) + "_" + str(auxiliary))
            if self.flow_type == "o-sylvester":
                z_k, log_det_jacobian = flow_k(z[i], r1[:, :, :, i],
                                               r2[:, :, :, i],
                                               q_ortho[i, :, :, :],
                                               b[:, :, :, i])
            elif self.flow_type == "h-sylvester":
                q_k = q_ortho[i]
                z_k, log_det_jacobian = flow_k(z[i], r1[:, :, :, i],
                                               r2[:, :, :, i], q_k,
                                               b[:, :, :, i])
            elif self.flow_type == "t-sylvester":
                # Alternate with reordering z for triangular flow: every
                # odd-numbered flow in this chain (index i, not the layer
                # index k) sees the permuted z.
                permute_z = self.flip_idx if i % 2 == 1 else None
                z_k, log_det_jacobian = flow_k(z[i],
                                               r1[:, :, :, i],
                                               r2[:, :, :, i],
                                               b[:, :, :, i],
                                               permute_z,
                                               sum_ldj=True)
            else:
                raise NotImplementedError(
                    "Non implemented flow type: {}".format(self.flow_type))
            z.append(z_k)
            self.log_det_j += log_det_jacobian

        log_p_zk = log_standard_gaussian(z[-1])
        # ln q(z_0) minus the flows' log-det-Jacobian  (not averaged)
        log_q_z0 = log_gaussian(z[0], z_mu, log_var=z_var) - self.log_det_j
        # N E_q0[ ln q(z_0) - ln p(z_k) ]
        # NOTE(review): other methods in this project store the KL on
        # self.kl_divergence; confirm self.model is the intended target here.
        self.model.kl_divergence = log_q_z0 - log_p_zk
        x_mean, _ = self.sample(z[-1], y)

        return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1]
예제 #11
0
파일: dec.py 프로젝트: baohq1595/vae-dec
 def log_gaussians(self, x, mus, logvars):
     """
     Evaluate ``log_gaussian`` of ``x`` against each of the
     ``self.n_centroids`` components.

     Row ``c`` of ``mus`` / ``logvars`` parameterises component ``c``. Each
     per-component result is reshaped to a column, and the columns are
     concatenated along dim 1, giving one column per centroid.
     """
     per_centroid = [
         log_gaussian(x, mus[c:c + 1, :], logvars[c:c + 1, :]).view(-1, 1)
         for c in range(self.n_centroids)
     ]
     return torch.cat(per_centroid, 1)
예제 #12
0
    def forward(self, x, i=None):
        """
        Encode ``x``, walk the latent hierarchy top-down accumulating the KL
        term, and reconstruct from the final latent.

        Side effects:
        - the accumulated KL is stored in ``self.kl_divergence``;
        - ``self.log_det_j`` is used as a scratch accumulator and deleted
          before returning.

        :param x: input batch; ``torch.tanh`` is applied before encoding
        :param i: unused (the name is reused for loop indices below)
        :return: (x_mu, z). On the Sylvester path ``z`` is the list
            [z_0, ..., z_K] produced by the last layer's flows; otherwise it is
            the latent returned by the encoder/decoder.
        """
        # Gather latent representation from encoders along with final z.
        latents = []
        x = torch.tanh(x)
        if self.flow_type in ["o-sylvester", "h-sylvester", "t-sylvester"]:
            for i in range(len(self.encoder)):
                q_param, x, z = self.encoder(x, i)
                latents.append(q_param)
        else:
            for i, encoder in enumerate(self.encoder):
                q_param, x = encoder(x)
                # First entry is used as z; the rest are the q parameters.
                z = q_param[0]
                q_param = q_param[1:]
                latents.append(q_param)
        latents = list(reversed(latents))
        kl_divergence = 0
        h = x
        self.log_det_j = 0
        for k, decoder in enumerate([-1, *self.decoder]):
            # If at top, encoder == decoder: use prior for KL.
            q_param = latents[k]
            if self.sylvester_flow:
                mu, log_var, r1, r2, q, b = q_param
                if k > 0:
                    z = [self.reparameterize(mu, log_var)]
                else:
                    z = [z]
                layer_idx = -1 - k
                q_ortho = self.batch_construct_orthogonal(q, layer_idx)

                # Normalizing flows z_0 -> z_K
                for i in range(self.n_flows):
                    flow_k = getattr(self, 'flow_' + str(k) + "_" + str(i))
                    z_k, log_det_jacobian = flow_k(z[i], r1[:, :, :, i],
                                                   r2[:, :, :, i],
                                                   q_ortho[i, :, :, :],
                                                   b[:, :, :, i])
                    z.append(z_k)
                    # NOTE(review): self.log_det_j is never reset between
                    # layers, so layer k's KL below also subtracts the log-dets
                    # of earlier layers' flows. Confirm this is intended.
                    self.log_det_j += log_det_jacobian

                # KL: ln q(z_0) - log-det - ln p(z_K)  (not averaged)
                log_p_zk = log_standard_gaussian(z[-1])
                log_q_z0 = log_gaussian(z[0], mu,
                                        log_var=log_var) - self.log_det_j
                kl_divergence += log_q_z0 - log_p_zk
            elif k == 0:
                kl_divergence += self._kld(z, q_param=q_param, i=k,
                                           h_last=h).abs()
            else:
                (mu, log_var) = q_param
                z, kl = decoder(z, mu, log_var)
                (_, q_param, p_param) = kl
                kl_divergence += self._kld(z,
                                           q_param=q_param,
                                           i=k,
                                           h_last=h,
                                           p_param=p_param).abs()

        # On the Sylvester path z holds the whole flow chain; reconstruct from
        # its last element. This is checked explicitly instead of catching any
        # exception, which used to hide real errors.
        x_mu = self.reconstruction(z[-1] if isinstance(z, list) else z)
        # Only the scratch attribute needs explicit cleanup; locals are freed
        # on return. Do not `del` r1/r2/q/b/q_ortho: they are unbound unless
        # the Sylvester branch ran.
        del self.log_det_j
        # Stored without Variable(...). In PyTorch >= 0.4 that wrapper returns
        # a detached tensor, which would block gradients through the KL term.
        self.kl_divergence = kl_divergence
        return x_mu, z