def get_logp(mean, std, action, expand=None):
    """Log-probability of ``action`` under a diagonal Gaussian policy.

    Args:
        mean (torch.Tensor): Mean of the Gaussian, shape ``(..., a_dim)``.
        std (torch.Tensor): Standard deviation, broadcastable with ``mean``.
        action (torch.Tensor): Actions to evaluate, shape ``(..., a_dim)``.
        expand (torch.Size or tuple, optional): If given, the distribution's
            batch shape is expanded to this shape before evaluation.

    Returns:
        torch.Tensor: Log-probability per sample (the event dimension is
            summed out by ``Independent``).
    """
    # Build the diagonal Gaussian once; expand only when requested.
    # (The original duplicated the distribution construction in both
    # branches of the if/else.)
    dist = Independent(Normal(mean, std), reinterpreted_batch_ndims=1)
    if expand is not None:
        dist = dist.expand(expand)
    return dist.log_prob(action)
def choose_action(self, state):
    """Sample an action from the current policy for one environment state.

    Args:
        state (np.ndarray): Raw observation from the environment.

    Returns:
        tuple: The sampled action as a numpy array, and its scalar
            log-probability under the policy distribution.
    """
    # Lift the single observation into a batch of one on the GPU.
    obs = torch.from_numpy(state).float().unsqueeze(0).cuda()
    mean, logstd = self.policy(obs)
    # Diagonal Gaussian over the action dimensions.
    policy_dist = Independent(Normal(mean.squeeze(), torch.exp(logstd)), 1)
    sampled = policy_dist.sample()
    logp = policy_dist.log_prob(sampled)
    return sampled.squeeze().cpu().numpy(), logp.item()
def get_log_prob_entropy(self, x, actions):
    """Log-probability of ``actions`` and policy entropy for a batch.

    Args:
        x (torch.Tensor): Batch of observations, shape
            ``(batch_size, obs_dim)``.
        actions (torch.Tensor): Batch of actions, shape
            ``(batch_size, a_dim)``.

    Returns:
        tuple:
            torch.Tensor: Log-probabilities, shape ``(batch_size, 1)``.
            torch.Tensor: Entropy of the first batch element's
                distribution (scalar tensor).
    """
    # self.forward returns (mean, log_std, std), each [batch_size, a_dim].
    action_mean, action_log_std, action_std = self.forward(x)
    diagn = Independent(Normal(loc=action_mean, scale=action_std), 1)
    # unsqueeze to (batch_size, 1) so the result aligns with value and
    # advantage tensors of that shape.
    log_prob = diagn.log_prob(actions).unsqueeze(dim=1)
    # NOTE(review): entropy()[0] keeps only the first batch element's
    # entropy — identical across the batch only when std does not depend
    # on x. Preserved as-is; confirm against the caller.
    entropy = diagn.entropy()[0]
    return log_prob, entropy
class TanhNormal(torch.distributions.Distribution):
    r"""A distribution induced by applying a tanh transformation to a Gaussian random variable.

    Algorithms like SAC and Pearl use this transformed distribution.
    It can be thought of as a distribution of X where

        :math:`Y ~ \mathcal{N}(\mu, \sigma)`
        :math:`X = tanh(Y)`

    Args:
        loc (torch.Tensor): The mean of this distribution.
        scale (torch.Tensor): The stdev of this distribution.

    """  # noqa: 501

    def __init__(self, loc, scale):
        # Diagonal Gaussian; the last dim of loc/scale is the event dim.
        self._normal = Independent(Normal(loc, scale), 1)
        super().__init__()

    def log_prob(self, value, pre_tanh_value=None, epsilon=1e-6):
        """The log likelihood of a sample on the this Tanh Distribution.

        Args:
            value (torch.Tensor): The sample whose loglikelihood is being
                computed.
            pre_tanh_value (torch.Tensor): The value prior to having the tanh
                function applied to it but after it has been sampled from the
                normal distribution.
            epsilon (float): Regularization constant. Making this value larger
                makes the computation more stable but less precise.

        Note:
              when pre_tanh_value is None, an estimate is made of what the
              value is. This leads to a worse estimation of the log_prob.
              If the value being used is collected from functions like
              `sample` and `rsample`, one can instead use functions like
              `sample_return_pre_tanh_value` or
              `rsample_return_pre_tanh_value`

        Returns:
            torch.Tensor: The log likelihood of value on the distribution.

        """
        # pylint: disable=arguments-differ
        if pre_tanh_value is None:
            # Numerically-stabilized atanh: atanh(v) = log((1+v)/(1-v)) / 2.
            pre_tanh_value = torch.log(
                (1 + epsilon + value) / (1 + epsilon - value)) / 2
        norm_lp = self._normal.log_prob(pre_tanh_value)
        # Change-of-variables correction: subtract log|d tanh(y)/dy|
        # = log(1 - x^2), clipped so gradients still flow near |x| = 1.
        ret = (norm_lp - torch.sum(
            torch.log(self._clip_but_pass_gradient((1. - value**2)) + epsilon),
            dim=-1))
        return ret

    def sample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal Distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients `do not` pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal Distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        z = self._normal.rsample(sample_shape)
        return torch.tanh(z)

    def rsample_with_pre_tanh_value(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal distribution.

        Returns the sampled value before the tanh transform is applied and the
        sampled value with the tanh transform applied to it.

        Args:
            sample_shape (list): shape of the return.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Samples from this distribution.
            torch.Tensor: Samples from the underlying
                :obj:`torch.distributions.Normal` distribution, prior to being
                transformed with `tanh`.

        """
        z = self._normal.rsample(sample_shape)
        return z, torch.tanh(z)

    def cdf(self, value):
        """Returns the CDF at the value.

        Returns the cumulative density/mass function evaluated at
        `value` on the underlying normal distribution.

        Args:
            value (torch.Tensor): The element where the cdf is being evaluated
                at.

        Returns:
            torch.Tensor: the result of the cdf being computed.

        """
        return self._normal.cdf(value)

    def icdf(self, value):
        """Returns the icdf function evaluated at `value`.

        Returns the icdf function evaluated at `value` on the underlying
        normal distribution.

        Args:
            value (torch.Tensor): The element where the cdf is being evaluated
                at.

        Returns:
            torch.Tensor: the result of the cdf being computed.

        """
        return self._normal.icdf(value)

    @classmethod
    def _from_distribution(cls, new_normal):
        """Construct a new TanhNormal distribution from a normal distribution.

        Args:
            new_normal (Independent(Normal)): underlying normal dist for
                the new TanhNormal distribution.

        Returns:
            TanhNormal: A new distribution whose underlying normal dist
                is new_normal.

        """
        # pylint: disable=protected-access
        # BUG FIX: the placeholder scale must be positive — Normal validates
        # scale > 0 by default (PyTorch >= 1.8), so zeros raised ValueError.
        # The placeholder is immediately replaced by new_normal below.
        new = cls(torch.zeros(1), torch.ones(1))
        new._normal = new_normal
        return new

    def expand(self, batch_shape, _instance=None):
        """Returns a new TanhNormal distribution.

        (or populates an existing instance provided by a derived class) with
        batch dimensions expanded to `batch_shape`. This method calls
        :class:`~torch.Tensor.expand` on the distribution's parameters. As
        such, this does not allocate new memory for the expanded distribution
        instance. Additionally, this does not repeat any args checking or
        parameter broadcasting in `__init__.py`, when an instance is first
        created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance(instance): new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            Instance: New distribution instance with batch dimensions expanded
            to `batch_size`.

        """
        new_normal = self._normal.expand(batch_shape, _instance)
        new = self._from_distribution(new_normal)
        return new

    def enumerate_support(self, expand=True):
        """Returns tensor containing all values supported by a discrete dist.

        The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Note:
            Calls the enumerate_support function of the underlying normal
            distribution.

        Returns:
            torch.Tensor: Tensor iterating over dimension 0.

        """
        return self._normal.enumerate_support(expand)

    @property
    def mean(self):
        """torch.Tensor: mean of the distribution."""
        return torch.tanh(self._normal.mean)

    @property
    def variance(self):
        """torch.Tensor: variance of the underlying normal distribution."""
        return self._normal.variance

    def entropy(self):
        """Returns entropy of the underlying normal distribution.

        Returns:
            torch.Tensor: entropy of the underlying normal distribution.

        """
        return self._normal.entropy()

    @staticmethod
    def _clip_but_pass_gradient(x, lower=0., upper=1.):
        """Clipping function that allows for gradients to flow through.

        Args:
            x (torch.Tensor): value to be clipped
            lower (float): lower bound of clipping
            upper (float): upper bound of clipping

        Returns:
            torch.Tensor: x clipped between lower and upper.

        """
        clip_up = (x > upper).float()
        clip_low = (x < lower).float()
        with torch.no_grad():
            clip = ((upper - x) * clip_up + (lower - x) * clip_low)
        return x + clip

    def __repr__(self):
        """Returns the parameterization of the distribution.

        Returns:
            str: The parameterization of the distribution and underlying
                distribution.

        """
        return self.__class__.__name__
).cuda() old_values = torch.cat( [torch.tensor(sample[4]).unsqueeze(0) for sample in rollout_batch] ).cuda() old_log_probs = torch.cat( [torch.tensor(sample[5]).unsqueeze(0) for sample in rollout_batch] ).cuda() means, logstd = policy(states) dist = Independent( Normal( means, torch.exp(logstd.unsqueeze(0).expand(batch_size, -1)) ), 1, ) log_probs = dist.log_prob(actions.squeeze()) values = value_fn(states) clipped_values = old_values + torch.clamp( values - old_values, -args.clip_range, args.clip_range ) l_vf1 = (values - targets).pow(2) l_vf2 = (clipped_values - targets).pow(2) value_fn_loss = 0.5 * torch.max(l_vf1, l_vf2).mean() value_fn_loss.backward() clip_grad_norm_(value_fn.parameters(), args.max_grad_norm) value_fn_opt.step() value_fn_opt.zero_grad() k = logstd.shape[0] entropy = (k / 2) * (1 + math.log(2 * math.pi)) + 0.5 * torch.log(