def get_log_prob_entropy(self, x, actions):
    """Compute action log-likelihoods and the policy entropy.

    Args:
        x (torch.Tensor): Observation batch fed to ``self.forward``.
        actions (torch.Tensor): Actions to evaluate, shape
            ``[batch_size, a_dim]``.

    Returns:
        tuple:
            - log_prob (torch.Tensor): Log-likelihood of each action under
              the diagonal Gaussian policy, shape ``[batch_size, 1]``.
            - entropy (torch.Tensor): Scalar entropy taken from the first
              batch element of the distribution's entropy.
    """
    # forward() yields per-dimension mean/log-std/std, each [batch_size, a_dim].
    action_mean, action_log_std, action_std = self.forward(x)
    # Independent(..., 1) folds the last dim into the event, giving a
    # diagonal multivariate normal with one log-prob per batch row.
    dist = Independent(Normal(loc=action_mean, scale=action_std), 1)
    log_prob = dist.log_prob(actions).unsqueeze(dim=1)  # [batch_size, 1]
    # NOTE(review): only the first batch element's entropy is returned —
    # presumably the std is shared across the batch; confirm with callers.
    entropy = dist.entropy()[0]
    return log_prob, entropy
class TanhNormal(torch.distributions.Distribution):
    r"""A distribution induced by applying a tanh transformation to a Gaussian random variable.

    Algorithms like SAC and Pearl use this transformed distribution.
    It can be thought of as a distribution of X where

        :math:`Y ~ \mathcal{N}(\mu, \sigma)`

        :math:`X = tanh(Y)`

    Args:
        loc (torch.Tensor): The mean of this distribution.
        scale (torch.Tensor): The stdev of this distribution.

    """  # noqa: 501

    def __init__(self, loc, scale):
        # Independent(..., 1) folds the last dim into the event, so
        # log_prob/entropy reduce over the action dimension.
        self._normal = Independent(Normal(loc, scale), 1)
        super().__init__()

    def log_prob(self, value, pre_tanh_value=None, epsilon=1e-6):
        """The log likelihood of a sample on the this Tanh Distribution.

        Args:
            value (torch.Tensor): The sample whose loglikelihood is being
                computed.
            pre_tanh_value (torch.Tensor): The value prior to having the tanh
                function applied to it but after it has been sampled from the
                normal distribution.
            epsilon (float): Regularization constant. Making this value larger
                makes the computation more stable but less precise.

        Note:
              when pre_tanh_value is None, an estimate is made of what the
              value is. This leads to a worse estimation of the log_prob.
              If the value being used is collected from functions like
              `sample` and `rsample`, one can instead use
              `rsample_with_pre_tanh_value` to obtain the exact
              pre-tanh value.

        Returns:
            torch.Tensor: The log likelihood of value on the distribution.

        """
        # pylint: disable=arguments-differ
        if pre_tanh_value is None:
            # Invert the tanh: atanh(v) = 0.5 * log((1 + v) / (1 - v)).
            # epsilon keeps the ratio finite as |value| approaches 1.
            pre_tanh_value = torch.log(
                (1 + epsilon + value) / (1 + epsilon - value)) / 2
        norm_lp = self._normal.log_prob(pre_tanh_value)
        # Change-of-variables correction: subtract log|d tanh(y)/dy|
        # = log(1 - tanh(y)^2) summed over the event dim. The clip keeps
        # the argument in [0, 1] while still passing gradients through.
        ret = (norm_lp - torch.sum(
            torch.log(self._clip_but_pass_gradient((1. - value**2)) + epsilon),
            dim=-1))
        return ret

    def sample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal Distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients `do not` pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal Distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        z = self._normal.rsample(sample_shape)
        return torch.tanh(z)

    def rsample_with_pre_tanh_value(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal distribution.

        Returns the sampled value before the tanh transform is applied and the
        sampled value with the tanh transform applied to it.

        Args:
            sample_shape (list): shape of the return.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Samples from this distribution.
            torch.Tensor: Samples from the underlying
                :obj:`torch.distributions.Normal` distribution, prior to being
                transformed with `tanh`.

        """
        z = self._normal.rsample(sample_shape)
        return z, torch.tanh(z)

    def cdf(self, value):
        """Returns the CDF at the value.

        Returns the cumulative density/mass function evaluated at
        `value` on the underlying normal distribution.

        Args:
            value (torch.Tensor): The element where the cdf is being evaluated
                at.

        Returns:
            torch.Tensor: the result of the cdf being computed.

        """
        return self._normal.cdf(value)

    def icdf(self, value):
        """Returns the icdf function evaluated at `value`.

        Returns the icdf function evaluated at `value` on the underlying
        normal distribution.

        Args:
            value (torch.Tensor): The element where the cdf is being evaluated
                at.

        Returns:
            torch.Tensor: the result of the cdf being computed.

        """
        return self._normal.icdf(value)

    @classmethod
    def _from_distribution(cls, new_normal):
        """Construct a new TanhNormal distribution from a normal distribution.

        Args:
            new_normal (Independent(Normal)): underlying normal dist for
                the new TanhNormal distribution.

        Returns:
            TanhNormal: A new distribution whose underlying normal dist is
                new_normal.

        """
        # pylint: disable=protected-access
        # Placeholder parameters only — the underlying distribution is
        # replaced immediately below. Scale must be positive (not zero),
        # otherwise Normal's argument validation raises ValueError.
        new = cls(torch.zeros(1), torch.ones(1))
        new._normal = new_normal
        return new

    def expand(self, batch_shape, _instance=None):
        """Returns a new TanhNormal distribution.

        (or populates an existing instance provided by a derived class) with
        batch dimensions expanded to `batch_shape`. This method calls
        :class:`~torch.Tensor.expand` on the distribution's parameters. As
        such, this does not allocate new memory for the expanded distribution
        instance. Additionally, this does not repeat any args checking or
        parameter broadcasting in `__init__.py`, when an instance is first
        created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance(instance): new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            Instance: New distribution instance with batch dimensions expanded
            to `batch_size`.

        """
        # NOTE(review): `_instance` is forwarded to the underlying
        # Independent distribution's expand; a TanhNormal instance would
        # fail its type check — callers appear to always pass None.
        new_normal = self._normal.expand(batch_shape, _instance)
        new = self._from_distribution(new_normal)
        return new

    def enumerate_support(self, expand=True):
        """Returns tensor containing all values supported by a discrete dist.

        The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Note:
            Calls the enumerate_support function of the underlying normal
            distribution.

        Returns:
            torch.Tensor: Tensor iterating over dimension 0.

        """
        return self._normal.enumerate_support(expand)

    @property
    def mean(self):
        """torch.Tensor: mean of the distribution.

        Note:
            This is tanh applied to the underlying normal's mean, not the
            true expectation of the transformed variable.
        """
        return torch.tanh(self._normal.mean)

    @property
    def variance(self):
        """torch.Tensor: variance of the underlying normal distribution."""
        return self._normal.variance

    def entropy(self):
        """Returns entropy of the underlying normal distribution.

        Returns:
            torch.Tensor: entropy of the underlying normal distribution.

        """
        return self._normal.entropy()

    @staticmethod
    def _clip_but_pass_gradient(x, lower=0., upper=1.):
        """Clipping function that allows for gradients to flow through.

        Args:
            x (torch.Tensor): value to be clipped
            lower (float): lower bound of clipping
            upper (float): upper bound of clipping

        Returns:
            torch.Tensor: x clipped between lower and upper.

        """
        clip_up = (x > upper).float()
        clip_low = (x < lower).float()
        # The correction term carries no gradient, so d(out)/dx == 1
        # everywhere even though the value is clamped to [lower, upper].
        with torch.no_grad():
            clip = ((upper - x) * clip_up + (lower - x) * clip_low)
        return x + clip

    def __repr__(self):
        """Returns the parameterization of the distribution.

        Returns:
            str: The parameterization of the distribution and underlying
                distribution.

        """
        return self.__class__.__name__