def _kld_beta_kerman_prior(self, conc1, conc2): """ Internal function to do a KL-div against the prior. :param conc1: concentration 1. :param conc2: concentration 2. :returns: batch_size tensor of kld against prior. :rtype: torch.Tensor """ prior = PD.Beta(zeros_like(conc1) + 1 / 3, zeros_like(conc2) + 1 / 3) beta = PD.Beta(conc1, conc2) return torch.sum(D.kl_divergence(beta, prior), -1)
def _kld_gaussian_N_0_1(mu, logvar): """ Internal member for kl-div against a N(0, 1) prior :param mu: mean :param logvar: log-variance :returns: batch_size tensor of kld :rtype: torch.Tensor """ standard_normal = D.Normal(zeros_like(mu), ones_like(logvar)) normal = D.Normal(mu, logvar) return torch.sum(D.kl_divergence(normal, standard_normal), -1)
def z_where_inv(z_where, clip_scale=5.0):
    """ Compute the "inverse" of a batch of ``[s, x, y]`` spatial-transform vectors.

    Each row ``[s, x, y]`` is mapped to ``[1/s', -x/s', -y/s']`` where
    ``s' = max(|s|, clip_scale)``. These are the parameters required to perform
    the inverse of the spatial transform performed in the generative model.

    :param z_where: 2-D tensor whose rows are ``[s, x, y]`` (presumably
        shape (batch, 3) -- TODO confirm against callers).
    :param clip_scale: lower bound applied to ``|s|`` before dividing.
    :returns: tensor of the same shape as ``z_where`` holding the inverse params.
    :rtype: torch.Tensor
    :raises ValueError: if any clipped scale is exactly zero (only possible
        when ``clip_scale <= 0``).
    """
    n = z_where.size(0)
    # Build [1, -x, -y]; dividing by the scale below yields [1/s, -x/s, -y/s].
    out = torch.cat((LocalizedSpatialTransformerFn.ng_ones([1, 1]).type_as(z_where).expand(n, 1),
                     -z_where[:, 1:]), 1)

    # abs(scale) ensures images aren't flipped.
    # NOTE(review): torch.max against clip_scale acts as a *floor*: every |s|
    # below clip_scale (default 5.0) is replaced by clip_scale. If the intent
    # was to cap the scale, or only to guard against division by ~0, revisit.
    scale = torch.max(torch.abs(z_where[:, 0:1]), zeros_like(z_where[:, 0:1]) + clip_scale)
    if (scale == 0).any():
        # Raise rather than calling exit(): library code must not terminate
        # the interpreter; callers can catch this and debug the bad input.
        raise ValueError("tensor scale of {} dim was 0!!".format(scale.shape))

    nan_check_and_break(scale, "scale")
    return out / scale
def _kld_gaussian_N_0_1(mu, logvar): standard_normal = D.Normal(zeros_like(mu), ones_like(logvar)) normal = D.Normal(mu, logvar) return torch.sum(D.kl_divergence(normal, standard_normal), -1)