import torch
from torch.distributions import Distribution, Normal


class TanhNormal(Distribution):
    """
    Represent the distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6, cuda=False):
        """
        :param normal_mean: Mean of the underlying normal distribution.
        :param normal_std: Std of the underlying normal distribution.
        :param epsilon: Numerical-stability epsilon used when computing log-prob.
        :param cuda: If True, draw the noise for rsample() on the GPU.
        """
        super().__init__()
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.cuda = cuda
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        # Distribution.sample_n is deprecated; sample((n,)) is the
        # equivalent call in modern PyTorch.
        z = self.normal.sample((n,))
        if return_pre_tanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        :param value: some value, x
        :param pre_tanh_value: arctanh(x)
        """
        if pre_tanh_value is None:
            # arctanh(x) = 0.5 * log((1 + x) / (1 - x))
            pre_tanh_value = torch.log((1 + value) / (1 - value)) / 2
        # Change of variables: log p_X(x) = log p_Z(z) - log |d tanh(z)/dz|
        #                                 = log p_Z(z) - log(1 - tanh(z)^2)
        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon
        )

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.normal.sample().detach()
        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case: draw standard-normal noise
        that carries no gradient, then shift and scale it so gradients flow
        through normal_mean and normal_std.
        """
        device = 'cuda' if self.cuda else 'cpu'
        noise = Normal(
            torch.zeros(self.normal_mean.size(), dtype=torch.float32, device=device),
            torch.ones(self.normal_std.size(), dtype=torch.float32, device=device),
        ).sample()
        # z is already differentiable w.r.t. normal_mean/normal_std whenever
        # they require grad; calling requires_grad_() on this non-leaf tensor
        # would raise an error, so no such call is made.
        z = self.normal_mean + self.normal_std * noise
        if return_pretanh_value:
            return torch.tanh(z), z
        return torch.tanh(z)
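# --- Usage sketch (illustrative, not part of the original code) ---
# A minimal example of how TanhNormal might be driven in an SAC-style
# policy. The batch size, action dimension, and the placeholder objective
# are assumptions made for the sake of the example; in practice the mean
# and log-std would come from a policy network.

mean = torch.zeros(4, 2, requires_grad=True)             # batch of 4, action dim 2
log_std = torch.full((4, 2), -0.5, requires_grad=True)

dist = TanhNormal(mean, log_std.exp())

# Reparameterized sample: gradients flow back into mean and log_std.
action, pre_tanh = dist.rsample(return_pretanh_value=True)

# Passing the pre-tanh value avoids the unstable arctanh reconstruction.
log_pi = dist.log_prob(action, pre_tanh_value=pre_tanh)

loss = -log_pi.mean()   # placeholder objective, stands in for the SAC actor loss
loss.backward()         # mean.grad and log_std.grad are now populated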
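# --- Sanity check (illustrative, not part of the original code) ---
# One way to validate the hand-rolled change-of-variables correction is to
# compare it against torch.distributions' built-in TanhTransform, which is
# available in recent PyTorch releases. This cross-check is an assumption
# about how the class might be tested, not part of the original code.

from torch.distributions import TransformedDistribution
from torch.distributions.transforms import TanhTransform

mean = torch.randn(5)
std = torch.rand(5) + 0.1

ours = TanhNormal(mean, std)
reference = TransformedDistribution(Normal(mean, std), [TanhTransform()])

# Compare log-probs on a common sample, passing the pre-tanh value so the
# arctanh reconstruction is not a source of error.
action, pre_tanh = ours.sample(return_pretanh_value=True)
print(ours.log_prob(action, pre_tanh_value=pre_tanh))
print(reference.log_prob(action))  # should agree up to the epsilon term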