Example #1
    def loss(self, outputs, weight):
        c_outputs, z_outputs, x_outputs = outputs

        # 1. Reconstruction term: Bernoulli log-likelihood of the inputs
        #    under the decoder means, averaged over datasets and samples
        x, x_mean = x_outputs
        recon_loss = bernoulli_log_likelihood(x.view(-1, 1, 28, 28), x_mean)
        recon_loss /= (self.batch_size * self.sample_size)

        # 2. KL Divergence terms
        kl = 0

        # a) Context divergence: KL between the approximate posterior over
        #    the context c and its standard normal prior
        c_mean, c_logvar = c_outputs
        kl_c = kl_diagnormal_stdnormal(c_mean, c_logvar)
        kl += kl_c

        # b) Latent divergences: KL between posterior and prior at each of
        #    the n_stochastic latent layers; only the first prior is
        #    conditioned on the context alone, so it is shared across the
        #    sample dimension (hence the narrower view for i == 0)
        qz_params, pz_params = z_outputs
        shapes = ((self.batch_size, self.sample_size, self.z_dim),
                  (self.batch_size, 1, self.z_dim))
        for i in range(self.n_stochastic):
            args = (qz_params[i][0].view(shapes[0]),
                    qz_params[i][1].view(shapes[0]),
                    pz_params[i][0].view(shapes[1] if i == 0 else shapes[0]),
                    pz_params[i][1].view(shapes[1] if i == 0 else shapes[0]))
            kl_z = kl_diagnormal_diagnormal(*args)
            kl += kl_z

        kl /= (self.batch_size * self.sample_size)

        # Variational lower bound, and a reweighted negation of it used as
        # the training loss; the unweighted bound is returned for monitoring
        vlb = recon_loss - kl
        loss = -((weight * recon_loss) - (kl / weight))

        return loss, vlb
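
The loss above assumes three likelihood/KL helpers that are not shown. A minimal sketch of them, written here from the standard closed-form expressions (the repository's own versions may differ in reduction, clamping, or argument order):

import torch

def bernoulli_log_likelihood(x, p, eps=1e-7):
    # sum over all elements of x*log(p) + (1 - x)*log(1 - p)
    p = p.clamp(eps, 1 - eps)
    return torch.sum(x * torch.log(p) + (1 - x) * torch.log(1 - p))

def kl_diagnormal_stdnormal(mean, logvar):
    # KL( N(mean, diag(exp(logvar))) || N(0, I) )
    return 0.5 * torch.sum(mean ** 2 + torch.exp(logvar) - logvar - 1)

def kl_diagnormal_diagnormal(q_mean, q_logvar, p_mean, p_logvar):
    # KL( N(q_mean, diag(exp(q_logvar))) || N(p_mean, diag(exp(p_logvar))) )
    return 0.5 * torch.sum(p_logvar - q_logvar
                           + (torch.exp(q_logvar) + (q_mean - p_mean) ** 2)
                           / torch.exp(p_logvar)
                           - 1)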
Example #2
    def summarize(self, dataset, output_size=6):
        """
        Iteratively discard samples until the dataset contains `output_size`
        items, at each step removing the sample whose absence perturbs the
        approximate posterior over the context the least. The indexing is
        awkward because this targets a PyTorch version without numpy-style
        indexing (it relies on the old Variable API).

        Assumes module-level imports:
            from itertools import combinations
            from torch.autograd import Variable
        """
        # dataset is expected to be a CUDA Variable; reshape it
        dataset = dataset.view(1, self.sample_size, self.n_features)
        # get approximate posterior over full dataset
        c_mean_full, c_logvar_full = self.statistic_network(dataset,
                                                            summarize=True)

        # iteratively discard until the dataset is of the required size
        # (assumes output_size <= self.sample_size)
        while dataset.size(1) > output_size:
            kl_divergences = []
            # need the KL divergence between the full-data posterior and the
            # posterior over every leave-one-out subset
            subset_indices = list(
                combinations(range(dataset.size(1)),
                             dataset.size(1) - 1))

            for subset_index in subset_indices:
                # pull out the subset with index_select
                ix = Variable(torch.LongTensor(subset_index).cuda())
                subset = dataset.index_select(1, ix)

                # calculate approximate posterior over subset
                c_mean, c_logvar = self.statistic_network(subset,
                                                          summarize=True)
                kl = kl_diagnormal_diagnormal(c_mean_full, c_logvar_full,
                                              c_mean, c_logvar)
                kl_divergences.append(kl.data[0])

            # the subset with minimal KL is the one whose posterior stays
            # closest to the full-data posterior, i.e. its missing sample
            # is the least informative and is the one removed
            best_index = kl_divergences.index(min(kl_divergences))

            # determine which samples to keep
            to_keep = subset_indices[best_index]
            to_keep = Variable(torch.LongTensor(to_keep).cuda())

            # keep only desired samples
            dataset = dataset.index_select(1, to_keep)

        # return pruned dataset
        return dataset
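
The snippet above targets the pre-0.4 Variable API. A sketch of the same selection loop against current PyTorch, assuming the same statistic_network and kl_diagnormal_diagnormal interfaces: Variable is gone, .data[0] becomes .item(), and index tensors can follow the dataset's device:

from itertools import combinations

import torch

def summarize(self, dataset, output_size=6):
    # reshape to (1, sample_size, n_features)
    dataset = dataset.view(1, self.sample_size, self.n_features)
    # approximate posterior over the full dataset
    c_mean_full, c_logvar_full = self.statistic_network(dataset,
                                                        summarize=True)

    while dataset.size(1) > output_size:
        n = dataset.size(1)
        subset_indices = list(combinations(range(n), n - 1))

        kl_divergences = []
        for subset_index in subset_indices:
            ix = torch.tensor(subset_index, device=dataset.device)
            subset = dataset.index_select(1, ix)
            c_mean, c_logvar = self.statistic_network(subset, summarize=True)
            kl = kl_diagnormal_diagnormal(c_mean_full, c_logvar_full,
                                          c_mean, c_logvar)
            kl_divergences.append(kl.item())

        # keep the subset whose posterior stays closest to the full posterior
        best_index = kl_divergences.index(min(kl_divergences))
        to_keep = torch.tensor(subset_indices[best_index],
                               device=dataset.device)
        dataset = dataset.index_select(1, to_keep)

    return dataset

In inference code this would typically run under torch.no_grad().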