def test(model, epoch, batch_size, num_imp_samples, data_loader):
    """Evaluate ``model`` on ``data_loader`` and print ELBO / marginal estimates.

    For every minibatch the model loss is accumulated, ``num_imp_samples``
    latent samples are drawn from the approximate posterior
    ``Normal(mu, exp(0.5 * logvar))`` and handed to ``marginal`` to estimate
    the log-likelihood.  Per-batch and overall statistics are printed; nothing
    is returned.

    Args:
        model: Object exposing ``eval()`` and ``loss(data)``; ``loss`` must
            return ``(loss, mu, logvar)``.
        epoch: Not used in the body; kept for interface compatibility.
        batch_size: Not used in the body; the actual batch size is read from
            each minibatch.
        num_imp_samples: Number of latent samples drawn per data point.
        data_loader: Iterable of minibatch tensors that also supports ``len``.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.eval()
    test_loss = 0
    num_samples = 0
    marginals = 0
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            data = data.to(device)
            loss, mu, logvar = model.loss(data)
            test_loss += loss.item()
            batchsize = data.data.shape[0]
            q_z = Normal(mu, torch.exp(0.5 * logvar))
            # sample_n puts the sample axis first: (num_imp_samples, batch, ...).
            # A bare .view(batch, num_imp_samples, -1) would only reinterpret
            # that memory and hand each data point samples drawn from other
            # data points' posteriors, so swap the axes before flattening.
            # NOTE(review): assumes mu is batch-first, as the target shape
            # (batchsize, num_imp_samples, -1) implies -- confirm against
            # `marginal`.
            zs = (
                q_z.sample_n(num_imp_samples)
                .transpose(0, 1)
                .reshape(batchsize, num_imp_samples, -1)
            )
            ll_batch = marginal(model, data, zs)
            marginals = marginals + torch.sum(ll_batch).item()
            print(
                f"Minibatch [{i}/{len(data_loader)}] elbo = {-loss.item()/batchsize}, Log Likelihood of minibatch = {ll_batch.mean().item()}"
            )
            num_samples += data.shape[0]
    print("\nELBO = {}, marginal probability = {}\n".format(
        -test_loss / num_samples, marginals / num_samples))
class TanhNormal:
    """
    Distribution of X = tanh(Z) with Z ~ N(mean, std).

    Every sampling method returns the pair ``(tanh(z), z)``.

    Note: this is not very numerically stable.

    Source: https://github.com/vitchyr/rlkit/blob/f136e140a57078c4f0f665051df74dffb1351f33/rlkit/torch/distributions.py
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.epsilon = epsilon
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)

    def sample_n(self, n):
        """Draw ``n`` samples; returns ``(tanh(z), z)``."""
        pre_tanh = self.normal.sample_n(n)
        return torch.tanh(pre_tanh), pre_tanh

    def log_prob(self, value, pre_tanh_value=None):
        """
        Log-density of ``value``, summed over the last dimension.

        :param value (torch.Tensor): some value, x; may be ``None`` when
            ``pre_tanh_value`` is given, in which case it is recomputed as
            ``tanh(pre_tanh_value)``
        :param pre_tanh_value (torch.Tensor): arctanh(x); derived from
            ``value`` when omitted
        :return: tensor with the last dimension reduced to size 1
        """
        if pre_tanh_value is None:
            # arctanh(x) = 1/2*log((1+x)/(1-x))
            pre_tanh_value = torch.log((1 + value) / (1 - value)) / 2.0
        if value is None:
            value = torch.tanh(pre_tanh_value)

        gaussian_term = self.normal.log_prob(pre_tanh_value)
        # change-of-variables correction for the tanh squashing
        squash_term = torch.log(1 - value.pow(2) + self.epsilon)
        return (gaussian_term - squash_term).sum(-1, keepdim=True)

    def sample(self):
        """
        Draw one detached sample; returns ``(tanh(z), z)``.

        Gradients will and should not pass through this operation.
        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        pre_tanh = self.normal.sample().detach()
        return torch.tanh(pre_tanh), pre_tanh

    def rsample(self):
        """
        Reparameterized sample; returns ``(tanh(z), z)``.
        """
        pre_tanh = self.normal.rsample()
        pre_tanh.requires_grad_()
        return torch.tanh(pre_tanh), pre_tanh
class TanhNormal(Distribution):
    """
    Distribution of X = tanh(Z) with Z ~ N(mean, std).

    Note: this is not very numerically stable.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.epsilon = epsilon
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        pre_tanh = self.normal.sample_n(n)
        squashed = torch.tanh(pre_tanh)
        return (squashed, pre_tanh) if return_pre_tanh_value else squashed

    def log_prob(self, value, pre_tanh_value=None):
        """
        Element-wise log-density of ``value``.

        :param value: some value, x
        :param pre_tanh_value: arctanh(x); computed from ``value`` if omitted
        :return: log-density with the tanh change-of-variables correction
        """
        if pre_tanh_value is None:
            ratio = (1 + value) / (1 - value)
            pre_tanh_value = torch.log(ratio) / 2
        squash_term = torch.log(1 - value * value + self.epsilon)
        return self.normal.log_prob(pre_tanh_value) - squash_term

    def sample(self, return_pretanh_value=False):
        """Draw one sample; optionally also return the pre-tanh value."""
        pre_tanh = self.normal.sample()
        squashed = torch.tanh(pre_tanh)
        return (squashed, pre_tanh) if return_pretanh_value else squashed

    def rsample(self, return_pretanh_value=False):
        """Reparameterized sample built as ``mean + std * unit-normal noise``."""
        unit_gaussian = Normal(
            ptu.zeros(self.normal_mean.size()),
            ptu.ones(self.normal_std.size()),
        )
        noise = Variable(unit_gaussian.sample())
        pre_tanh = self.normal_mean + self.normal_std * noise
        squashed = torch.tanh(pre_tanh)
        return (squashed, pre_tanh) if return_pretanh_value else squashed
def sample_action(self, state, batch_size=1):
    """Sample a tanh-squashed, scaled action for ``state``.

    Standard-normal noise of shape ``(batch_size, action_size)`` is drawn,
    combined with the two outputs of ``self.Actor(state)`` as
    ``mu + log_std * noise``, squashed with tanh and multiplied by
    ``self.action_scale``.

    Returns:
        ``(action, log_prob)`` where ``log_prob`` is the standard-normal
        log-density of the raw noise.

    NOTE(review): the second actor output is used directly as the scale,
    despite its name -- confirm whether an ``exp`` is intended.
    NOTE(review): ``log_prob`` describes the noise, not the squashed action.
    """
    mu, log_std = self.Actor(state)
    unit_normal = Normal(0, 1)
    flat_noise = unit_normal.sample_n(self.action_size * batch_size)
    noise = flat_noise.to(self.device).view((batch_size, self.action_size))
    noise_log_prob = unit_normal.log_prob(noise)
    squashed = torch.tanh(mu + log_std * noise)
    return squashed * self.action_scale, noise_log_prob
class Normal(Distribution):
    """Thin wrapper that delegates to a ``PyNormal`` instance and adds ``kl``.

    NOTE(review): ``params`` reads ``.std`` from the wrapped object; confirm
    that the ``PyNormal`` in use exposes that attribute.
    """

    def __init__(self, mean, std):
        self.normal = PyNormal(mean, std)

    def mean(self):
        """Return the mean of the wrapped distribution."""
        return self.normal.mean

    def params(self):
        """Return ``[mean, std]`` of the wrapped distribution."""
        wrapped = self.normal
        return [wrapped.mean, wrapped.std]

    def sample(self):
        """
        Draw a single sample, or a single batch of samples when the
        distribution parameters are batched.
        """
        return self.normal.sample()

    def sample_n(self, n):
        """
        Draw n samples, or n batches of samples when the distribution
        parameters are batched.
        """
        return self.normal.sample_n(n)

    def log_prob(self, value):
        """
        Log of the probability density/mass function evaluated at `value`.

        Args:
            value (Tensor or Variable):
        """
        return self.normal.log_prob(value)

    def kl(self, other):
        """
        KL-divergence between two Normals:
            KL[N(u_i, s_i) || N(u_j, s_j)]
        where ``self`` supplies (u_i, s_i) and ``other`` supplies (u_j, s_j).

        Returns a tensor with the dimensionality of the location variable.

        Raises:
            ValueError: if ``other`` is not a ``Normal``.
        """
        if not isinstance(other, Normal):
            raise ValueError('Impossible')

        mu_i, sigma_i = self.params()  # [mean, std]
        mu_j, sigma_j = other.params()  # [mean, std]

        var_i = sigma_i**2.
        var_j = sigma_j**2.

        quadratic = 1. / (2. * var_j) * ((mu_i - mu_j)**2. + var_i - var_j)
        log_ratio = torch.log(sigma_j) - torch.log(sigma_i)
        return quadratic + log_ratio
class TanhNormal(Distribution):
    """
    Distribution of X = tanh(Z) with Z ~ N(mean, std).
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-8):
        """Store the Gaussian parameters and the log-prob stability epsilon."""
        self.epsilon = epsilon
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        latent = self.normal.sample_n(n)
        if not return_pre_tanh_value:
            return torch.tanh(latent)
        return torch.tanh(latent), latent

    def log_prob(self, value, pre_tanh_value=None):
        """Element-wise log-density of ``value``.

        ``pre_tanh_value`` is arctanh(value); it is computed from ``value``
        when not supplied.
        """
        if pre_tanh_value is None:
            pre_tanh_value = torch.log((1 + value) / (1 - value)) / 2
        gaussian_term = self.normal.log_prob(pre_tanh_value)
        # change-of-variables correction for the tanh squashing
        squash_term = torch.log(1 - value * value + self.epsilon)
        return gaussian_term - squash_term

    def sample(self, return_pretanh_value=False):
        """Draw one detached sample; optionally also return the pre-tanh value."""
        latent = self.normal.sample().detach()
        if not return_pretanh_value:
            return torch.tanh(latent)
        return torch.tanh(latent), latent

    def rsample(self, return_pretanh_value=False):
        """Reparameterized sample: ``mean + std * noise`` with unit-normal noise
        created on the same devices as the stored parameters."""
        zero_loc = torch.zeros(self.normal_mean.size(), device=self.normal_mean.device)
        unit_scale = torch.ones(self.normal_std.size(), device=self.normal_std.device)
        noise = Normal(zero_loc, unit_scale).sample()
        latent = self.normal_mean + self.normal_std * noise
        latent.requires_grad_()
        if not return_pretanh_value:
            return torch.tanh(latent)
        return torch.tanh(latent), latent
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        super(TanhNormal, self).__init__()
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        """
        Draw ``n`` samples.

        :param n: number of samples
        :param return_pre_tanh_value: when true, return ``(tanh(z), z)``
            instead of just ``tanh(z)``
        """
        z = self.normal.sample_n(n)
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        Element-wise log-density of ``value``.

        :param value: some value, x
        :param pre_tanh_value: arctanh(x); computed from ``value`` if omitted
        :return: Gaussian log-density of the pre-tanh value minus the
            log-Jacobian of tanh, ``log(1 - x^2 + epsilon)``
        """
        if pre_tanh_value is None:
            # arctanh(x) = 1/2 * log((1 + x) / (1 - x)).  The result must be
            # bound to pre_tanh_value itself; otherwise None would be passed
            # to the Normal below.
            pre_tanh_value = torch.log((1 + value) / (1 - value)) / 2
        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon
        )

    def sample(self, return_pre_tanh_value=False):
        """
        Draw one sample.

        :param return_pre_tanh_value: when true, return ``(tanh(z), z)``
            instead of just ``tanh(z)``
        """
        z = self.normal.sample()
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)
class TanhNormal(Distribution):
    """
    Distribution of X = tanh(Z) with Z ~ N(mean, std).

    Note: this is not very numerically stable.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.epsilon = epsilon
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)

    @staticmethod
    def _squash(z, with_pre_tanh):
        """Apply tanh to ``z``; optionally pair the result with ``z`` itself."""
        squashed = torch.tanh(z)
        if with_pre_tanh:
            return squashed, z
        return squashed

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        return self._squash(self.normal.sample_n(n), return_pre_tanh_value)

    def log_prob(self, value, pre_tanh_value=None):
        """
        Element-wise log-density of ``value``.

        :param value: some value, x
        :param pre_tanh_value: arctanh(x); computed from ``value`` if omitted
        :return: log-density with the tanh change-of-variables correction
        """
        if pre_tanh_value is None:
            odds = (1 + value) / (1 - value)
            pre_tanh_value = torch.log(odds) / 2
        gaussian_term = self.normal.log_prob(pre_tanh_value)
        squash_term = torch.log(1 - value * value + self.epsilon)
        return gaussian_term - squash_term

    def sample(self, return_pretanh_value=False):
        """
        Draw one detached sample.

        Gradients will and should *not* pass through this operation.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        return self._squash(self.normal.sample().detach(), return_pretanh_value)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case: ``mean + std * noise`` with
        unit-normal noise.
        """
        unit_gaussian = Normal(
            ptu.zeros(self.normal_mean.size()),
            ptu.ones(self.normal_std.size()),
        )
        pre_tanh = self.normal_mean + self.normal_std * unit_gaussian.sample()
        pre_tanh.requires_grad_()
        return self._squash(pre_tanh, return_pretanh_value)
class TanhNormal(Distribution):

    """
    [[Source]](https://github.com/seba-1511/cherry/blob/master/cherry/models/tabular.py)

    **Description**

    A Normal distribution whose samples are passed through a Tanh, as commonly
    used with the Soft Actor-Critic.

    Besides the usual sampling methods, `sample_and_log_prob` and
    `rsample_and_log_prob` return both the samples and their log-densities.
    Those log-densities are computed from the pre-activation values for
    numerical stability.

    **Credit**

    Adapted from Vitchyr Pong's RLkit:
    https://github.com/vitchyr/rlkit/blob/master/rlkit/torch/distributions.py

    **References**

    1. Haarnoja et al. 2018. “Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.” arXiv [cs.LG].
    2. Haarnoja et al. 2018. “Soft Actor-Critic Algorithms and Applications.” arXiv [cs.LG].

    **Arguments**

    * **normal_mean** (tensor) - Mean of the Normal distribution.
    * **normal_std** (tensor) - Standard deviation of the Normal distribution.

    **Example**
    ~~~python
    mean = th.zeros(5)
    std = th.ones(5)
    dist = TanhNormal(mean, std)
    samples = dist.rsample()
    logprobs = dist.log_prob(samples)  # Numerically unstable :(
    samples, logprobs = dist.rsample_and_log_prob()  # Stable :)
    ~~~

    """

    def __init__(self, normal_mean, normal_std):
        self.normal = Normal(normal_mean, normal_std)
        self.normal_mean = normal_mean
        self.normal_std = normal_std

    @staticmethod
    def _squash_correction(value):
        # log(1 - value^2 + 1e-6): tanh change-of-variables term, computed
        # from the squashed value
        return th.log1p(-value**2 + 1e-6)

    def sample_n(self, n):
        return th.tanh(self.normal.sample_n(n))

    def log_prob(self, value):
        # arctanh(value) written with log1p
        pre_tanh = (th.log1p(value) - th.log1p(-value)).mul(0.5)
        correction = self._squash_correction(value)
        return self.normal.log_prob(pre_tanh) - correction

    def sample(self):
        return th.tanh(self.normal.sample().detach())

    def sample_and_log_prob(self):
        pre_tanh = self.normal.sample().detach()
        squashed = th.tanh(pre_tanh)
        correction = self._squash_correction(squashed)
        return squashed, self.normal.log_prob(pre_tanh) - correction

    def rsample_and_log_prob(self):
        pre_tanh = self.normal.rsample()
        squashed = th.tanh(pre_tanh)
        correction = self._squash_correction(squashed)
        return squashed, self.normal.log_prob(pre_tanh) - correction

    def rsample(self):
        pre_tanh = self.normal.rsample()
        pre_tanh.requires_grad_()
        return th.tanh(pre_tanh)
class VRNN(nn.Module):
    """LSTM caption model with a Gaussian latent variable (variational RNN).

    At a "variational" step the model computes a prior p(z|h) and a posterior
    q(z|x, h), samples z from the posterior via the re-parameterization trick
    and feeds ``[x, z]`` to the LSTM.  At a "deterministic" step a zero latent
    is fed instead.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, latent_size,
                 num_layers_lstm):
        """Set the hyper-parameters and build the layers."""
        super(VRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size + latent_size, hidden_size,
                            num_layers_lstm, batch_first=True)
        # q(z|x, h): emits mean and sigma side by side
        self.q_z = nn.Linear(embed_size + hidden_size, latent_size * 2)
        # p(z|h): emits mean and sigma side by side
        self.prior = nn.Linear(hidden_size, latent_size * 2)
        # gaussian noise generator for the re-parameterization trick
        self.normal = Normal(torch.zeros(latent_size, ),
                             torch.ones(latent_size, ))
        # q(x|z) backwards inference
        self.q_x = nn.Linear(latent_size + hidden_size, vocab_size)
        # deterministically computed distribution over dictionary output
        # (vanilla output)
        self.det_x = nn.Linear(hidden_size, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Initialize weights uniformly in [-0.1, 0.1] and biases to zero."""
        self.embed.weight.data.uniform_(-0.1, 0.1)
        for layer in (self.q_z, self.prior, self.q_x, self.det_x):
            layer.weight.data.uniform_(-0.1, 0.1)
            layer.bias.data.fill_(0)

    def _zero_latent(self, batch_size):
        """Return an all-zero latent with one row per batch element."""
        # q_z emits mean and sigma concatenated, so half its width is the
        # latent size.  Floor division keeps the size an int; true division
        # produces a float on Python 3, which torch.zeros rejects.
        return to_var(torch.zeros(batch_size, self.q_z.out_features // 2))

    def _sample_latent(self, mu, sigma):
        """Re-parameterized sample ``noise * sigma + mu``, one row per batch element."""
        return to_var(self.normal.sample_n(mu.shape[0])) * sigma + mu

    def forward(self, features, captions, lengths, states=None, z_0=None,
                z_step='all'):
        """
        Conduct forward pass, outputting prior, posterior, and inference
        distributions.

        z_step:
            'all' - compute distributions for every step
            t     - compute distributions during step t and the
                    deterministic output for all steps following

        Returns ``(p_mus, p_sigmas), (q_mus, q_sigmas), q_xs, det_xs``.
        ``det_xs`` is the packed deterministic output, or an empty list when
        no deterministic outputs were produced (e.g. ``z_step='all'``).
        """
        z_padding = self._zero_latent(features.shape[0])
        if z_0 is None:
            z_0 = z_padding
        features = torch.cat([features, z_0], dim=1).unsqueeze(1)

        # conduct initial lstm step to embed image features in internal states
        h, states = self.lstm(features, states)
        h = h.squeeze(1)

        embeddings = self.embed(captions)
        p_mus, p_sigmas, q_mus, q_sigmas, q_xs, det_xs = [], [], [], [], [], []
        for i in range(max(lengths)):
            if z_step == 'all' or i == z_step:
                # forward pass for the VRNN cell
                # prior parameters from h_tm1
                p_mu, p_sigma = self.prior(h).chunk(2, dim=1)
                # posterior parameters from x_t and h_tm1
                q_mu, q_sigma = self.q_z(
                    torch.cat([embeddings[:, i], h], dim=1)).chunk(2, dim=1)
                # sample z from q_z using re-parameterization
                z = self._sample_latent(q_mu, q_sigma)
                # get q_x from z_t and h_tm1
                q_x = self.q_x(torch.cat([z, h], dim=1))
                q_x = nn.functional.log_softmax(q_x, dim=1)
                # lstm step to get h_t from x_t, z_t, h_tm1 - unsqueeze for
                # the lstm api
                inputs = torch.cat([embeddings[:, i], z], dim=1).unsqueeze(1)
                h, states = self.lstm(inputs, states)
                h = h.squeeze(1)

                p_mus.append(p_mu)
                p_sigmas.append(p_sigma)
                q_mus.append(q_mu)
                q_sigmas.append(q_sigma)
                q_xs.append(q_x)
            else:
                # forward pass for the deterministic LSTM cell
                inputs = torch.cat([embeddings[:, i], z_padding],
                                   dim=1).unsqueeze(1)
                h, states = self.lstm(inputs, states)
                h = h.squeeze(1)
                # if we are decoding, compute distribution over dictionary
                # to output (z_step is an int whenever this branch runs)
                if i > z_step:
                    det_xs.append(self.det_x(h))

        # NOTE(review): only the first collected step of each list is kept;
        # the original carried a commented-out pack_padded_sequence over all
        # steps here -- confirm which behavior is intended.
        [p_mus, p_sigmas, q_mus, q_sigmas,
         q_xs] = [t[0] for t in [p_mus, p_sigmas, q_mus, q_sigmas, q_xs]]

        # Nothing to pack when no deterministic step ran (always the case for
        # z_step='all'); stacking an empty list would raise.
        if det_xs:
            det_xs = pack_padded_sequence(
                torch.stack(det_xs, dim=1),
                [l - z_step - 1 for l in lengths],
                batch_first=True)[0]
        return (p_mus, p_sigmas), (q_mus, q_sigmas), q_xs, det_xs

    def encode(self, features, ground_truth, z_step, states=None, z_0=None):
        """
        Encode the first t ground truths, for t = z_step, conducting a
        variational rnn cell forward pass during step t, in order to obtain
        q(z_t|x_t, h_t-1) and p(z_t|h_t-1) that will both define (along with
        x_t, h_t-1) a distribution over h_t, and thereby a distribution over
        decoded captions beginning with the specified ground truths.

        Returns ``(p_mu, p_sigma), (q_mu, q_sigma), states, x_t`` where
        ``x_t`` is the embedding of the ground truth at ``z_step``.
        """
        z_padding = self._zero_latent(features.shape[0])
        if z_0 is None:
            z_0 = z_padding
        features = torch.cat([features, z_0], dim=1).unsqueeze(1)

        # conduct initial lstm step to embed image features in internal
        # states; keep its output so h is defined even when z_step == 0
        h, states = self.lstm(features, states)
        h = h.squeeze(1)

        embeddings = self.embed(ground_truth)

        # deterministic forward pass for all steps before z_step
        for i in range(z_step):
            inputs = torch.cat([embeddings[:, i], z_padding],
                               dim=1).unsqueeze(1)
            h, states = self.lstm(inputs, states)
            h = h.squeeze(1)

        # forward pass for the VRNN cell during z_step
        p_mu, p_sigma = self.prior(h).chunk(2, dim=1)
        q_mu, q_sigma = self.q_z(
            torch.cat([embeddings[:, z_step], h], dim=1)).chunk(2, dim=1)

        return (p_mu, p_sigma), (q_mu, q_sigma), states, embeddings[:, z_step]

    def decode(self, states, x, z_step, max_length=20):
        """
        Greedily decode the remainder of a caption given h_t, c_t and x_t-1.

        h: [layers X batch X hidden]
        c: [layers X batch X hidden]
        x: [batch X embed]

        ``max_length`` is the total caption length budget;
        ``max_length - z_step - 1`` tokens are generated.  Returns the
        generated token ids stacked along dim 1.
        """
        z_padding = self._zero_latent(x.shape[0])
        caption = []
        for _ in range(max_length - z_step - 1):
            inputs = torch.cat([x, z_padding], dim=1).unsqueeze(1)
            h, states = self.lstm(inputs, states)
            h = h.squeeze(1)
            x = self.det_x(h)  # [batch X dictionary]
            x = x.max(1)[1]  # [batch]
            caption.append(x.data)
            x = self.embed(x)
        return torch.stack(caption, dim=1)

    def sample(self, features, ground_truth, states=None, z_0=None, z_step=3):
        """
        Encode latent distribution at step z_step, sample, then decode latent
        vector + hidden states into the remaining part of the image caption.

        Returns ``(captions_p, captions_q)``: captions decoded from a prior
        sample and from a posterior sample respectively.
        """
        # get prior and posterior densities at z_step; forward the caller's
        # initial states / latent instead of silently dropping them
        (p_mu, p_sigma), (q_mu, q_sigma), states, x = self.encode(
            features, ground_truth, z_step, states=states, z_0=z_0)

        # sample to get z
        z_p = self._sample_latent(p_mu, p_sigma)
        z_q = self._sample_latent(q_mu, q_sigma)

        # get hidden states to give to decoder
        inputs_p = torch.cat([x, z_p], dim=1).unsqueeze(1)
        inputs_q = torch.cat([x, z_q], dim=1).unsqueeze(1)
        h_p, states_p = self.lstm(inputs_p, states)
        h_q, states_q = self.lstm(inputs_q, states)
        h_p, h_q = h_p.squeeze(1), h_q.squeeze(1)

        # get predicted embedding of next step from h
        x_p, x_q = self.det_x(h_p), self.det_x(h_q)  # [batch X dictionary]
        x_p, x_q = x_p.max(1)[1], x_q.max(1)[1]  # [batch]
        x_p, x_q = self.embed(x_p), self.embed(x_q)  # [batch X embed]

        # decode from z, x, hidden states the remaining caption
        captions_p = self.decode(states_p, x_p, z_step)
        captions_q = self.decode(states_q, x_q, z_step)

        return captions_p, captions_q
class TanhNormal(Distribution):
    """
    Distribution of X = tanh(Z) with Z ~ N(mean, std).

    The ``mean``, ``variance`` and ``stddev`` properties are those of the
    underlying normal.

    Note: this is not very numerically stable.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        Args:
            normal_mean (Tensor): Mean of the normal distribution
            normal_std (Tensor): Std of the normal distribution
            epsilon (Double): Numerical stability epsilon when computing log-prob.
        """
        super(TanhNormal, self).__init__()
        self._epsilon = epsilon
        self._normal_mean = normal_mean
        self._normal_std = normal_std
        self._normal = Normal(normal_mean, normal_std)

    @property
    def mean(self):
        return self._normal.mean

    @property
    def variance(self):
        return self._normal.variance

    @property
    def stddev(self):
        return self._normal.stddev

    @property
    def epsilon(self):
        return self._epsilon

    def sample(self, return_pretanh_value=False):
        """Draw one detached sample; optionally also return the pre-tanh value."""
        pre_tanh = self._normal.sample().detach()
        squashed = torch.tanh(pre_tanh)
        if return_pretanh_value:
            return squashed, pre_tanh
        return squashed

    def rsample(self, return_pretanh_value=False):
        """Reparameterized sample; optionally also return the pre-tanh value."""
        pre_tanh = self._normal.rsample()
        squashed = torch.tanh(pre_tanh)
        if return_pretanh_value:
            return squashed, pre_tanh
        return squashed

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        pre_tanh = self._normal.sample_n(n)
        squashed = torch.tanh(pre_tanh)
        if return_pre_tanh_value:
            return squashed, pre_tanh
        return squashed

    @staticmethod
    def _arctanh(value):
        # arctanh(x) = 1/2 * log((1 + x) / (1 - x))
        return torch.log((1 + value) / (1 - value)) / 2

    def log_prob(self, value, pre_tanh_value=None):
        """
        Log of the probability density function evaluated at `value`.

        Args:
            value (Tensor):
            pre_tanh_value (Tensor): arctanh(value); computed when omitted

        Returns:
            log_prob (Tensor)
        """
        if pre_tanh_value is None:
            pre_tanh_value = self._arctanh(value)
        gaussian_term = self._normal.log_prob(pre_tanh_value)
        squash_term = torch.log(1. - value * value + self._epsilon)
        return gaussian_term - squash_term

    def cdf(self, value, pre_tanh_value=None):
        """Normal CDF evaluated at the pre-tanh value (computed from ``value``
        when ``pre_tanh_value`` is omitted)."""
        if pre_tanh_value is None:
            pre_tanh_value = self._arctanh(value)
        return self._normal.cdf(pre_tanh_value)
# KL-divergence between a diagonal multivariate normal, # and a standard normal distribution (with zero mean and unit variance) # In other words, we are punishing it if it's distribution moves away from a standard normal dist KLD = -0.5 * torch.sum(1 + dist.scale.log() - dist.loc.pow(2) - dist.scale) return BCE + KLD # + # You can try the KLD here with differen't distribution p = Normal(loc=1, scale=2) q = Normal(loc=0, scale=1) kld = torch.distributions.kl.kl_divergence(p, q) # plot the distributions ps = p.sample_n(10000).numpy() qs = q.sample_n(10000).numpy() sns.kdeplot(ps, label='p') sns.kdeplot(qs, label='q') plt.title(f"KLD(p|q) = {kld:2.2f}\nKLD({p}|{q})") plt.legend() plt.show() # - # ## Exercise 1: KLD # # Run the above cell with while changing Q. # # - Use the code above and test if the KLD is higher for distributions that overlap more #
def sample_n(self, n):
    """Draw ``n`` samples via ``Normal.sample_n`` and pass them through a sigmoid."""
    return F.sigmoid(Normal.sample_n(self, n))
def sample_n(self, n):
    """Draw ``n`` samples via ``Normal.sample_n`` and exponentiate them."""
    return T.exp(Normal.sample_n(self, n))
import matplotlib.pyplot as plt
import torch
from torch.distributions import Normal

from mmvae_hub.experiment_vis.utils import load_experiment

# Load a trained experiment and switch it to evaluation mode.
exp_uid = 'polymnist_iwmogfm_multiloss_2021_09_28_11_08_24_742311'
exp = load_experiment(_id=exp_uid)
exp.set_eval_mode()

num_samples = 10

# Zero-mean Gaussian with standard deviation sqrt(1/2) over the class
# embedding dimensions.
Gf = Normal(
    torch.zeros(exp.flags.class_dim, device=exp.flags.device),
    torch.tensor(1 / 2).sqrt() *
    torch.ones(exp.flags.class_dim, device=exp.flags.device))

# sample((n,)) is the supported spelling of the deprecated sample_n(n) and
# yields the same (num_samples, class_dim) tensor.
z_Gf = Gf.sample((num_samples,))

# Map the Gaussian samples through the reverse pass of the flow.
zss, log_det_J = exp.mm_vae.flow.rev(z_Gf)

# Mean of each modality's likelihood given the class embeddings.
rec_mods = {}
for mod_key, mod in exp.mm_vae.modalities.items():
    rec_mods[mod_key] = mod.calc_likelihood(None, class_embeddings=zss).mean

# One figure per sample, with the three modalities side by side.
plotted_mods = ('m0', 'm1', 'm2')
for k in range(num_samples):
    plt.figure()
    for col, mod_key in enumerate(plotted_mods, start=1):
        plt.subplot(1, len(plotted_mods), col)
        plt.imshow(rec_mods[mod_key][k].detach().cpu().numpy().swapaxes(0, -1))
class TanhNormal(Distribution):
    """
    Distribution of X = tanh(Z) with Z ~ N(mean, std).

    Note: this is not very numerically stable.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.epsilon = epsilon
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)

    def sample_n(self, n, return_pre_tanh_value=False):
        """Draw ``n`` samples; optionally also return the pre-tanh values."""
        latent = self.normal.sample_n(n)
        if not return_pre_tanh_value:
            return torch.tanh(latent)
        return torch.tanh(latent), latent

    def log_prob(self, value, pre_tanh_value=None):
        """
        Element-wise log-density of ``value``.

        Adapted from
        https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L73

        The correction term is mathematically equivalent to
        log(1 - tanh(x)^2).

        Derivation:

        log(1 - tanh(x)^2)
         = log(sech(x)^2)
         = 2 * log(sech(x))
         = 2 * log(2e^-x / (e^-2x + 1))
         = 2 * (log(2) - x - log(e^-2x + 1))
         = 2 * (log(2) - x - softplus(-2x))

        :param value: some value, x
        :param pre_tanh_value: arctanh(x); when omitted it is computed from
            ``value`` clamped to [-0.999999, 0.999999]
        :return: log-density with the tanh change-of-variables correction
        """
        if pre_tanh_value is None:
            value = torch.clamp(value, -0.999999, 0.999999)
            # arctanh written as a difference of logs
            pre_tanh_value = torch.log(1 + value) / 2 - torch.log(1 - value) / 2

        gaussian_term = self.normal.log_prob(pre_tanh_value)
        log_two = ptu.from_numpy(np.log([2.]))
        softplus_term = torch.nn.functional.softplus(-2. * pre_tanh_value)
        squash_term = 2. * (log_two - pre_tanh_value - softplus_term)
        return gaussian_term - squash_term

    def sample(self, return_pretanh_value=False):
        """
        Draw one detached sample.

        Gradients will and should *not* pass through this operation.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        latent = self.normal.sample().detach()
        if not return_pretanh_value:
            return torch.tanh(latent)
        return torch.tanh(latent), latent

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case: ``mean + std * noise`` with
        unit-normal noise.
        """
        unit_gaussian = Normal(
            ptu.zeros(self.normal_mean.size()),
            ptu.ones(self.normal_std.size()),
        )
        latent = self.normal_mean + self.normal_std * unit_gaussian.sample()
        latent.requires_grad_()
        if not return_pretanh_value:
            return torch.tanh(latent)
        return torch.tanh(latent), latent