def evaluate_mj(mu, sigma, actions, clamp=None):
    """Evaluate continuous actions batchwise under a diagonal Gaussian.

    Args:
        mu: (batch, D) means of the Gaussian.
        sigma: (batch, D) per-dimension variances (used as the diagonal of the
            covariance matrix, as in the original implementation).
        actions: (batch, D) actions to score.
        clamp: optional float; when given, log-probabilities are clamped to
            the range [-clamp, clamp].

    Returns:
        Tuple of (log_probs, entropy), each of shape (batch,).
    """
    # diag_embed builds the (batch, D, D) diagonal covariance in one call and
    # inherits sigma's dtype/device.  The previous as_strided copy into a
    # freshly allocated CPU float32 tensor broke on CUDA or float64 inputs.
    cov = torch.diag_embed(sigma)
    gauss = MultivariateNormal(mu, cov)
    log_probs = gauss.log_prob(actions)
    if clamp is not None:
        log_probs = torch.clamp(log_probs, min=-clamp, max=clamp)
    return log_probs, gauss.entropy()
def sample(self, x: torch.Tensor, raw_action: Optional[torch.Tensor] = None, deterministic: bool = False) -> Tuple[torch.Tensor, ...]:
    """Sample (or re-evaluate) an action from the Gaussian policy.

    Args:
        x: input state batch, fed to ``self.forward``.
        raw_action: optional pre-tanh action to evaluate instead of sampling.
        deterministic: when True, return the distribution mean as the action.

    Returns:
        (action, log_prob, entropy) with log_prob and entropy of shape
        (batch, 1).
    """
    mean, log_std = self.forward(x)
    # scale_tril is the Cholesky factor of a diagonal covariance, i.e. the
    # per-dimension standard deviations on the diagonal.
    dist = MultivariateNormal(loc=mean, scale_tril=torch.diag_embed(log_std.exp()))
    # BUG FIX: the original used `if not raw_action`, which raises
    # "Boolean value of Tensor ... is ambiguous" for any multi-element
    # tensor, so passing a raw_action crashed.  Test for None explicitly.
    if raw_action is None:
        if self._reparameterize:
            raw_action = dist.rsample()
        else:
            raw_action = dist.sample()
    action = torch.tanh(raw_action) if self._squash else raw_action
    log_prob = dist.log_prob(raw_action).unsqueeze(-1)
    if self._squash:
        # Change-of-variables correction for the tanh squashing.
        log_prob -= self._squash_correction(raw_action)
    entropy = dist.entropy().unsqueeze(-1)
    if deterministic:
        # Mirror the squash setting on the deterministic path; the original
        # applied tanh unconditionally even when _squash was False, which
        # contradicted the stochastic branch above.
        action = torch.tanh(dist.mean) if self._squash else dist.mean
    return action, log_prob, entropy
def get_entropy(self, state):
    """Return the entropy of the policy distribution at *state*.

    Builds a multivariate normal from ``self.forward``'s mean and std
    (std placed on the diagonal of the Cholesky factor) and returns its
    entropy reshaped to (batch, 1).
    """
    mu, std = self.forward(state)
    scale = torch.diag_embed(std)
    dist = MultivariateNormal(loc=mu, scale_tril=scale)
    return dist.entropy().view(state.size(0), 1)