Example #1
    # Assumes: import torch; import torch.distributions as D
    def _kld_beta_kerman_prior(self, conc1, conc2):
        """ Internal function for the KL divergence against the Kerman
        neutral prior, Beta(1/3, 1/3).

        :param conc1: first concentration parameter.
        :param conc2: second concentration parameter.
        :returns: batch_size tensor of the KL divergence against the prior.
        :rtype: torch.Tensor

        """
        prior = D.Beta(torch.zeros_like(conc1) + 1 / 3,
                       torch.zeros_like(conc2) + 1 / 3)
        beta = D.Beta(conc1, conc2)
        return torch.sum(D.kl_divergence(beta, prior), -1)
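
A minimal standalone sanity check of the same computation (illustrative shapes and values only; it relies on stock PyTorch, which registers a closed-form Beta-to-Beta KL in torch.distributions):

    import torch
    import torch.distributions as D

    conc1 = torch.rand(4, 8) + 0.5   # a batch of positive concentrations
    conc2 = torch.rand(4, 8) + 0.5
    prior = D.Beta(torch.full_like(conc1, 1 / 3), torch.full_like(conc2, 1 / 3))
    kld = torch.sum(D.kl_divergence(D.Beta(conc1, conc2), prior), -1)
    print(kld.shape)  # torch.Size([4]) -- one KL value per batch element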
Example #2
    # Assumes: import torch; import torch.distributions as D
    def _kld_gaussian_N_0_1(mu, logvar):
        """ Internal member for the KL divergence against a N(0, 1) prior.

        :param mu: mean of the posterior.
        :param logvar: log-variance of the posterior.
        :returns: batch_size tensor of the KL divergence.
        :rtype: torch.Tensor

        """
        standard_normal = D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
        # D.Normal expects a standard deviation, so convert the log-variance.
        normal = D.Normal(mu, torch.exp(0.5 * logvar))
        return torch.sum(D.kl_divergence(normal, standard_normal), -1)
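
For this Gaussian case the KL divergence also has the familiar VAE closed form, -0.5 * sum(1 + logvar - mu^2 - exp(logvar)), so a standalone check (illustrative values only) can compare the two:

    import torch
    import torch.distributions as D

    mu, logvar = torch.randn(4, 8), torch.randn(4, 8)
    posterior = D.Normal(mu, torch.exp(0.5 * logvar))
    prior = D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
    dist_kl = torch.sum(D.kl_divergence(posterior, prior), -1)
    closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), -1)
    print(torch.allclose(dist_kl, closed_form, atol=1e-5))  # True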
Example #3
    # Assumes: import torch
    def z_where_inv(z_where, clip_scale=5.0):
        # Take a batch of z_where vectors and compute their "inverse",
        # i.e. for each row compute:
        #   [s, x, y] -> [1/s, -x/s, -y/s]
        # These are the parameters required to invert the spatial
        # transform performed in the generative model.
        n = z_where.size(0)
        # torch.ones (no grad by default) replaces the snippet's ng_ones helper.
        out = torch.cat((torch.ones(1, 1).type_as(z_where).expand(n, 1),
                         -z_where[:, 1:]), 1)

        # Divide all entries by the scale: abs() ensures images aren't
        # flipped, and flooring the magnitude at clip_scale keeps the
        # division numerically safe.
        scale = torch.max(torch.abs(z_where[:, 0:1]),
                          torch.zeros_like(z_where[:, 0:1]) + clip_scale)
        if torch.sum(scale == 0) > 0:
            raise ValueError(
                "scale tensor of shape {} contained zeros".format(scale.shape))

        # Project-specific helper that raises if the tensor contains NaNs.
        nan_check_and_break(scale, "scale")
        out = out / scale

        return out
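
A quick illustrative round trip (hypothetical values, and a stub standing in for the project's nan_check_and_break helper): because the scale's magnitude is floored at clip_scale, the output only matches the exact inverse [1/s, -x/s, -y/s] when |s| >= clip_scale, as below; smaller scales get clipped.

    import torch

    def nan_check_and_break(t, name):  # stand-in stub for the project helper
        assert not torch.isnan(t).any(), name

    z = torch.tensor([[6.0, 0.3, -0.9]])  # [s, x, y] with |s| >= clip_scale
    print(z_where_inv(z))                 # tensor([[ 0.1667, -0.0500,  0.1500]])
    # i.e. [1/6, -0.3/6, 0.9/6] -- composing this with the forward transform
    # parameterised by z recovers the identity.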