def forward(self, kl_mean, kl_var=None):
    """Return primal and dual loss terms from MMPO.

    Parameters
    ----------
    kl_mean : torch.Tensor
        KL divergence of the policy mean component.
    kl_var : torch.Tensor, optional
        KL divergence of the policy covariance component.
        Defaults to zero when not provided.
    """
    if self.epsilon_mean == 0.0 and not self.regularization:
        return Loss()
    if kl_var is None:
        kl_var = torch.zeros_like(kl_mean)

    kl_mean, kl_var = kl_mean.mean(), kl_var.mean()
    reg_loss = self.eta_mean * kl_mean + self.eta_var * kl_var

    if self.regularization:
        return Loss(reg_loss=reg_loss)
    else:
        if self.separated_kl:
            mean_loss = self._eta_mean() * (self.epsilon_mean - kl_mean).detach()
            var_loss = self._eta_var() * (self.epsilon_var - kl_var).detach()
            dual_loss = mean_loss + var_loss
        else:
            dual_loss = self._eta_mean() * (self.epsilon_mean - kl_mean).detach()
        return Loss(dual_loss=dual_loss, reg_loss=reg_loss)
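# Illustrative sketch (not part of the library): how the hard-constraint branch
# above adapts the Lagrange multiplier. The names `eta`, `epsilon`, and `kl`
# below are stand-ins for self._eta_mean(), self.epsilon_mean, and kl_mean.
import torch

eta = torch.nn.Parameter(torch.tensor(1.0))  # learnable multiplier / temperature.
epsilon = 0.1                                # KL budget.
kl = torch.tensor(0.3)                       # measured KL (detached in the loss).

# While the constraint is violated (kl > epsilon), gradient descent on
# dual_loss increases eta; once the constraint holds, eta decreases again.
dual_loss = eta * (epsilon - kl).detach()
dual_loss.backward()
print(eta.grad)  # tensor(-0.2000): descending this gradient increases eta.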
def forward(self, entropy):
    """Return primal and dual loss terms from the entropy loss.

    Parameters
    ----------
    entropy : torch.Tensor
        Entropy of the policy at the sampled states.
    """
    if self.target_entropy == 0.0 and not self.regularization:
        return Loss()
    dual_loss = self._eta() * (entropy - self.target_entropy).detach()
    reg_loss = -self.eta * entropy
    return Loss(dual_loss=dual_loss, reg_loss=reg_loss)
def critic_loss(self, observation):
    """Get critic loss.

    This is usually computed using fitted value iteration and semi-gradients,
    i.e. critic_loss = criterion(pred_q, target_q.detach()).

    Parameters
    ----------
    observation: Observation.
        Sampled observations.
        It is of shape B x N x d, where:
            - B is the batch size,
            - N is the number of steps in the N-step return,
            - d is the dimension of the attribute.

    Returns
    -------
    loss: Loss.
        Loss with parameters loss, critic_loss, and td_error filled.
    """
    if self.critic is None:
        return Loss()
    pred_q = self.get_value_prediction(observation)

    # Get target_q with semi-gradients.
    with torch.no_grad():
        target_q = self.get_value_target(observation)
        if pred_q.shape != target_q.shape:  # Reshape in case of ensembles.
            assert isinstance(self.critic, NNEnsembleQFunction)
            target_q = target_q.unsqueeze(-1).repeat_interleave(
                self.critic.num_heads, -1
            )

        td_error = pred_q - target_q  # no gradients for td-error.
        if self.criterion.reduction == "mean":
            td_error = torch.mean(td_error)
        elif self.criterion.reduction == "sum":
            td_error = torch.sum(td_error)

    critic_loss = self.criterion(pred_q, target_q)

    if isinstance(self.critic, NNEnsembleQFunction):
        # Ensembles have the ensemble head as last dimension; sum all heads.
        critic_loss = critic_loss.sum(-1)
        td_error = td_error.sum(-1)

    # Take mean over the time coordinate.
    critic_loss = critic_loss.mean(-1)
    td_error = td_error.mean(-1)

    return Loss(critic_loss=critic_loss, td_error=td_error)
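# Illustrative sketch (not part of the library): the semi-gradient fitted
# value-iteration update that the critic loss above performs. `q_net` and the
# one-step TD target are stand-ins for get_value_prediction / get_value_target.
import torch
import torch.nn as nn

q_net = nn.Linear(4, 1)               # toy Q-network on concatenated (state, action).
criterion = nn.MSELoss(reduction="mean")
gamma = 0.99

state_action = torch.randn(8, 4)      # batch of 8 transitions.
reward = torch.randn(8, 1)
next_value = torch.randn(8, 1)        # e.g. from a frozen target network.

pred_q = q_net(state_action)
with torch.no_grad():                 # semi-gradient: no gradient through the target.
    target_q = reward + gamma * next_value

critic_loss = criterion(pred_q, target_q)   # = criterion(pred_q, target_q.detach()).
critic_loss.backward()                      # gradients flow only through pred_q.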
def forward(self, inequality_value):
    """Return primal and dual loss terms from the inequality constraint.

    Parameters
    ----------
    inequality_value : torch.Tensor
        Current value of the constrained quantity.
    """
    if self.inequality_zero == 0.0 and not self.regularization:
        return Loss()
    dual_loss = self._dual() * (self.inequality_zero - inequality_value).detach()
    reg_loss = self._dual().detach() * inequality_value
    return Loss(dual_loss=dual_loss, reg_loss=reg_loss)
def model_augmented_critic_loss(self, observation):
    """Get model-based critic loss."""
    with torch.no_grad():
        state, action = observation.state[..., 0, :], observation.action[..., 0, :]
        sim_observation = self.simulate(
            state, self.policy, initial_action=action, stack_obs=True
        )
        if not self.td_k:
            sim_observation.state = observation.state[..., :1, :]
            sim_observation.action = observation.action[..., :1, :]

    pred_q = self.base_algorithm.get_value_prediction(sim_observation)

    # Get target_q with semi-gradients.
    with torch.no_grad():
        target_q = self.get_value_target(sim_observation)
        if not self.td_k:
            target_q = target_q.reshape(self.num_samples, *pred_q.shape[:2]).mean(0)

        if pred_q.shape != target_q.shape:  # Reshape in case of ensembles.
            assert isinstance(self.critic, NNEnsembleQFunction)
            target_q = target_q.unsqueeze(-1).repeat_interleave(
                self.critic.num_heads, -1
            )

    critic_loss = self.base_algorithm.criterion(pred_q, target_q)
    return Loss(critic_loss=critic_loss)
def forward(self, observation, idx=None):
    """Compute losses at state/idx pairs."""
    state = observation.state
    if idx is None:
        idx = torch.arange(state.shape[0])
    return Loss(
        dual_loss=self.dual(observation, idx=idx)
        + self.get_discount_dual_loss(state, idx)
    )
def forward(self, observation):
    """Compute the losses.

    Given an Observation, it computes the losses directly.
    Given a list of trajectories, it tries to stack them to vectorize operations.
    If stacking fails, it iterates over the trajectories instead.
    """
    if isinstance(observation, Observation):
        trajectories = [observation]
    elif len(observation) > 1:
        try:
            # When possible, stack to parallelize the trajectories.
            # This requires all trajectories to have equal length.
            trajectories = [stack_list_of_tuples(observation)]
        except RuntimeError:
            trajectories = observation
    else:
        trajectories = observation

    self.reset_info()

    loss = Loss()
    for trajectory in trajectories:
        loss += self.actor_loss(trajectory)
        loss += self.critic_loss(trajectory)
        loss += self.regularization_loss(trajectory, len(trajectories))

    return loss / len(trajectories)
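# Illustrative sketch (not part of the library): why stacking only works for
# equal-length trajectories. With equal lengths, per-trajectory tensors can be
# batched and processed in one vectorized pass; otherwise the loop above falls
# back to iterating. `torch.stack` stands in for stack_list_of_tuples.
import torch

traj_a = torch.randn(10, 3)                # 10 steps of 3-dimensional states.
traj_b = torch.randn(10, 3)
batched = torch.stack([traj_a, traj_b])    # shape [2, 10, 3]: single forward pass.

traj_c = torch.randn(7, 3)                 # a trajectory of different length ...
try:
    torch.stack([traj_a, traj_c])
except RuntimeError as error:              # ... cannot be stacked, so iterate instead.
    print(error)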
def forward(self, action_log_p, value, target):
    """Return primal and dual loss terms from REPS.

    Parameters
    ----------
    action_log_p : torch.Tensor
        A [state_batch, 1] tensor of log probabilities of the corresponding
        actions under the policy.
    value : torch.Tensor
        The value function V(s), evaluated with gradients.
    target : torch.Tensor
        The value target r + gamma * V(s'), evaluated with gradients.
    """
    td = target - value
    weights = td / self._eta()
    normalizer = torch.logsumexp(weights, dim=0)
    dual_loss = self._eta() * (self.epsilon + normalizer)

    # Clamping is crucial for stability so that it does not converge to a delta.
    weighted_log_p = torch.exp(weights).clamp_max(1e2).detach() * action_log_p
    log_likelihood = weighted_log_p.mean()

    return Loss(policy_loss=-log_likelihood, dual_loss=dual_loss, td_error=td.mean())
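# Illustrative sketch (not part of the library): the sample-based REPS dual
# g(eta) = eta * (epsilon + logsumexp(td / eta)) and the exponentiated weights
# used for the weighted maximum-likelihood policy fit. `eta` stands in for the
# positive temperature returned by self._eta().
import torch

eta = torch.tensor(0.5, requires_grad=True)
epsilon = 0.1
td = torch.tensor([0.2, -0.1, 0.4, 0.0])   # Bellman errors r + gamma * V(s') - V(s).

dual = eta * (epsilon + torch.logsumexp(td / eta, dim=0))
dual.backward()                            # eta is tuned by descending the dual.

weights = torch.softmax(td / eta.detach(), dim=0)  # proportional to exp(td / eta).
print(weights)                             # larger TD-error -> larger policy weight.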
def forward(self, observation):
    """Roll out the model and call the base algorithm on the transitions."""
    self.base_algorithm.reset_info()

    loss = Loss()
    loss += self.base_algorithm.actor_loss(observation)
    loss += self.model_augmented_critic_loss(observation)
    loss += self.base_algorithm.regularization_loss(observation)
    return loss
def forward(self, observation):
    """Compute path-wise loss."""
    if self.policy is None or self.critic is None:
        return Loss()
    state = observation.state
    pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
    action = self.policy.action_scale * pi.rsample().clamp(-1, 1)

    with DisableGradient(self.critic):
        q = self.critic(state, action)
        if isinstance(self.critic, NNEnsembleQFunction):
            q = q[..., 0]

    # Take mean over the time coordinate, if present.
    if q.dim() > 1:
        q = q.mean(dim=1)

    return Loss(policy_loss=-q)
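# Illustrative sketch (not part of the library): the pathwise derivative this
# loss relies on. rsample() keeps the action differentiable w.r.t. the policy
# parameters, so -Q(s, a) backpropagates through the action while the critic's
# own weights stay frozen (the analogue of DisableGradient).
import torch
import torch.nn as nn

mean = torch.zeros(2, requires_grad=True)  # toy policy parameter.
std = 0.5 * torch.ones(2)
critic = nn.Linear(2, 1)
for p in critic.parameters():
    p.requires_grad_(False)                # freeze the critic, like DisableGradient.

pi = torch.distributions.Normal(mean, std)
action = pi.rsample().clamp(-1, 1)         # reparameterized, differentiable sample.
policy_loss = -critic(action).mean()
policy_loss.backward()
print(mean.grad)                           # gradient reached the policy mean.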
def actor_loss(self, observation):
    """Use the model to compute the gradient loss."""
    state, action = observation.state, observation.action
    next_state, done = observation.next_state, observation.done

    # Infer eta.
    action_mean, action_chol = self.policy(state)
    with torch.no_grad():
        eta = torch.inverse(action_chol) @ (action - action_mean).unsqueeze(-1)

    # Compute entropy and log_probability.
    pi = tensor_to_distribution((action_mean, action_chol))
    _, log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)

    # Compute off-policy weight.
    with torch.no_grad():
        weight = self.get_ope_weight(state, action, observation.log_prob_action)

    with DisableGradient(
        self.dynamical_model,
        self.reward_model,
        self.termination_model,
        self.critic_target,
    ):
        # Compute re-parameterized policy sample.
        action = (action_mean + (action_chol @ eta).squeeze(-1)).clamp(-1, 1)

        # Infer xi.
        ns_mean, ns_chol = self.dynamical_model(state, action)
        with torch.no_grad():
            xi = torch.inverse(ns_chol) @ (next_state - ns_mean).unsqueeze(-1)

        # Compute re-parameterized next-state sample.
        ns = ns_mean + (ns_chol @ xi).squeeze(-1)

        # Compute reward.
        r = tensor_to_distribution(self.reward_model(state, action, ns)).rsample()
        r = r[..., 0]

        next_v = self.value_function(ns)
        if isinstance(self.critic, (NNEnsembleValueFunction, NNEnsembleQFunction)):
            next_v = next_v[..., 0]

        v = r + self.gamma * next_v * (1 - done)

    return Loss(policy_loss=-(weight * v)).reduce(self.criterion.reduction)
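# Illustrative sketch (not part of the library): the "infer eta" trick used
# above. Given an observed action a and a Gaussian policy with mean mu and
# Cholesky factor L, solving L @ eta = a - mu (with eta detached) and rebuilding
# a = mu + L @ eta reproduces the same action, now differentiable in mu and L.
import torch

mu = torch.tensor([0.1, -0.2], requires_grad=True)
L = torch.diag(torch.tensor([0.5, 0.3])).requires_grad_(True)
observed_action = torch.tensor([0.4, 0.1])

with torch.no_grad():
    eta = torch.inverse(L) @ (observed_action - mu).unsqueeze(-1)

reparam_action = mu + (L @ eta).squeeze(-1)
print(torch.allclose(reparam_action, observed_action))  # True: same numbers ...
reparam_action.sum().backward()                         # ... but gradients reach mu and L.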
def model_augmented_critic_loss(self, observation):
    """Get model-based critic loss."""
    pred_q = self.base_algorithm.get_value_prediction(observation)

    # Get target_q with semi-gradients.
    with torch.no_grad():
        target_q = self.get_value_target(observation)
        if pred_q.shape != target_q.shape:  # Reshape in case of ensembles.
            assert isinstance(self.critic, NNEnsembleQFunction)
            target_q = target_q.unsqueeze(-1).repeat_interleave(
                self.critic.num_heads, -1
            )

    critic_loss = self.base_algorithm.criterion(pred_q, target_q)
    return Loss(critic_loss=critic_loss)
def score_actor_loss(self, observation, linearized=False):
    """Get score actor loss for policy gradients."""
    state, action, reward, next_state, done, *r = observation

    log_p, ratio = self.get_log_p_and_ope_weight(state, action)

    with torch.no_grad():
        adv = self.returns(observation)
        if self.standardize_returns:
            adv = (adv - adv.mean()) / (adv.std() + self.eps)

    if linearized:
        score = ratio * adv
    else:
        score = discount_sum(log_p * adv, self.gamma)

    return Loss(policy_loss=-score)
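# Illustrative sketch (not part of the library): the two surrogates selected by
# `linearized` above. At ratio = 1 (on-policy) both have the same gradient, so
# ratio * adv is a first-order stand-in for the score-function objective
# log_pi(a|s) * adv.
import torch

log_p = torch.tensor([-0.9, -1.2, -0.4], requires_grad=True)
adv = torch.tensor([1.0, -0.5, 2.0])

score_objective = (log_p * adv).sum()       # score-function (REINFORCE) form.
ratio = (log_p - log_p.detach()).exp()      # equals 1 but carries gradients.
linearized_objective = (ratio * adv).sum()

g_score, = torch.autograd.grad(score_objective, log_p)
g_linear, = torch.autograd.grad(linearized_objective, log_p)
print(torch.allclose(g_score, g_linear))    # True.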
def actor_loss(self, observation):
    """Get actor loss."""
    state, action, *_ = observation
    pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
    entropy, _ = get_entropy_and_log_p(pi, action, self.policy.action_scale)

    policy_loss = integrate(
        lambda a: -pi.log_prob(a)
        * (
            self.critic(state, self.policy.action_scale * a)
            - self.value_target(state)
        ).detach(),
        pi,
        num_samples=self.num_samples,
    ).sum()
    return Loss(policy_loss=policy_loss).reduce(self.criterion.reduction)
def actor_loss(self, observation):
    """Return primal and dual loss terms from REPS."""
    state, action, reward, next_state, done, *r = observation

    # Compute scaled TD-errors.
    value = self.critic(state)
    # For the dual function we need the full gradient, not the semi-gradient!
    target = self.get_value_target(observation)

    pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
    _, action_log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)

    reps_loss = self.reps_loss(action_log_p, value, target)
    self._info.update(reps_eta=self.reps_loss.eta)
    return reps_loss + Loss(dual_loss=(1.0 - self.gamma) * value.mean())
def actor_loss(self, trajectory):
    """Get actor loss."""
    state, action, reward, next_state, done, *r = trajectory
    log_p, ratio = self.get_log_p_and_ope_weight(state, action)

    with torch.no_grad():
        adv = self.returns(trajectory)
        if self.standardize_returns:
            adv = (adv - adv.mean()) / (adv.std() + self.eps)

    # Compute the clipped surrogate loss.
    weighted_advantage = ratio * adv
    clipped_advantage = ratio.clamp(1 - self.epsilon(), 1 + self.epsilon()) * adv
    # Instead of TRPO's trust-region constraint, take the minimum of the
    # unclipped and clipped objectives.
    surrogate_loss = -torch.min(weighted_advantage, clipped_advantage)

    return Loss(policy_loss=surrogate_loss).reduce(self.criterion.reduction)
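# Illustrative sketch (not part of the library): the effect of the clipped
# surrogate above. Taking the minimum removes any incentive to push the
# importance ratio outside [1 - eps, 1 + eps] in the profitable direction.
import torch

eps = 0.2
ratio = torch.tensor([0.5, 1.0, 1.5])      # pi_new / pi_old for three samples.
adv = torch.tensor([1.0, 1.0, 1.0])        # positive advantages.

weighted = ratio * adv
clipped = ratio.clamp(1 - eps, 1 + eps) * adv
surrogate = torch.min(weighted, clipped)
print(surrogate)                           # tensor([0.5000, 1.0000, 1.2000])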
def actor_loss(self, observation):
    """Return primal and dual loss terms from REPS."""
    state, action, reward, next_state, done, *r = observation

    # Compute scaled TD-errors.
    value = self.critic(state)
    # For the dual function we need the full gradient, not the semi-gradient!
    target = self.get_value_target(observation)

    td = target - value
    weights = td / self.eta()
    normalizer = torch.logsumexp(weights, dim=0)
    dual = self.eta() * (self.epsilon + normalizer) + (1.0 - self.gamma) * value

    nll = self._policy_weighted_nll(state, action, weights)

    return Loss(dual_loss=dual.mean(), policy_loss=nll, td_error=td)
def actor_loss(self, observation):
    """Return primal and dual loss terms from Q-REPS."""
    state, action, reward, next_state, done, *r = observation

    # Calculate dual variables.
    value = self.critic(state)
    target = self.get_value_target(observation)
    q_value = self.q_function(state, action)

    td = target - q_value
    self._info.update(td=td)

    # Calculate weights.
    weights_td = self.eta() * td  # type: torch.Tensor
    if weights_td.ndim == 1:
        weights_td = weights_td.unsqueeze(-1)

    dual = 1 / self.eta() * torch.logsumexp(weights_td, dim=-1)
    dual += (1 - self.gamma) * value.squeeze(-1)

    return Loss(dual_loss=dual.mean(), td_error=td)
def actor_loss(self, observation):
    """Get actor loss.

    This is different for each algorithm.

    Parameters
    ----------
    observation: Observation.
        Sampled observations.
        It is of shape B x N x d, where:
            - B is the batch size,
            - N is the number of steps in the N-step return,
            - d is the dimension of the attribute.

    Returns
    -------
    loss: Loss.
        Loss with parameters loss, policy_loss, and regularization_loss filled.
    """
    return Loss()
def actor_loss(self, observation) -> Loss:
    """Compute actor loss."""
    state, action = observation.state[..., 0, :], observation.action[..., 0, :]
    action_mean, action_chol = self.policy(state)

    # Infer eta.
    with torch.no_grad():
        delta = action / self.policy.action_scale - action_mean
        eta = torch.inverse(action_chol) @ delta.unsqueeze(-1)

    # Compute re-parameterized policy sample.
    action = self.policy.action_scale * (
        action_mean + (action_chol @ eta).squeeze(-1)
    ).clamp(-1.0, 1.0)

    # Propagate gradient.
    with DisableGradient(self.critic):
        q = self.critic(observation.state[..., 0, :], action)
        if isinstance(self.critic, NNEnsembleQFunction):
            q = q[..., 0]

    return Loss(policy_loss=-q).reduce(self.criterion.reduction)
def actor_loss(self, observation):
    """Compute the losses for one step of MPO.

    Parameters
    ----------
    observation : Observation
        The states at which to compute the losses.
    """
    state = observation.state
    value_prediction = self.critic(state)

    with torch.no_grad():
        value_estimate, obs = mb_return(
            state=state,
            dynamical_model=self.dynamical_model,
            policy=self.old_policy,
            reward_model=self.reward_model,
            num_steps=1,
            gamma=self.gamma,
            value_function=self.critic_target,
            num_samples=self.num_samples,
            reward_transformer=self.reward_transformer,
            termination_model=self.termination_model,
            reduction="min",
        )
    q_values = value_estimate
    log_p, _ = self.get_log_p_and_ope_weight(obs.state, obs.action)

    # Since the actions come from the policy, the value is the expected q-value.
    mpo_loss = self.mpo_loss(q_values=q_values, action_log_p=log_p.squeeze(-1))
    value_loss = self.criterion(value_prediction, q_values.mean(dim=0))
    td_error = value_prediction - q_values.mean(dim=0)
    critic_loss = Loss(critic_loss=value_loss, td_error=td_error)
    self._info.update(eta=self.mpo_loss.eta)

    return mpo_loss.reduce(self.criterion.reduction) + critic_loss
def forward(self, q_values, action_log_p):
    """Return primal and dual loss terms from MPO.

    Parameters
    ----------
    q_values : torch.Tensor
        A [n_action_samples, state_batch, 1] tensor of values for
        state-action pairs.
    action_log_p : torch.Tensor
        A [n_action_samples, state_batch, 1] tensor of log probabilities
        of the corresponding actions under the policy.
    """
    # Make sure the Lagrange multipliers stay positive.
    # self.project_etas()

    # E-step: Solve Problem (7).
    # Create a weighted, sample-based representation of the optimal policy q, Eq. (8).
    # Compute the dual loss for the constraint KL(q || old_pi) < eps.
    q_values = q_values.detach() * (torch.tensor(1.0) / self._eta())
    normalizer = torch.logsumexp(q_values, dim=0)
    num_actions = torch.tensor(1.0 * action_log_p.shape[0])

    dual_loss = self._eta() * (
        self.epsilon + torch.mean(normalizer) - torch.log(num_actions)
    )
    # Non-parametric representation of the optimal policy.
    weights = torch.exp(q_values - normalizer.detach())

    # M-step: Solve Problem (10).
    # Fit the parametric policy to the representation from the E-step.
    # Maximize the log-likelihood of the weighted log probabilities, subject to the
    # KL divergence between old_pi and new_pi being smaller than epsilon.
    weighted_log_p = torch.sum(weights * action_log_p, dim=0)
    log_likelihood = weighted_log_p

    return Loss(policy_loss=-log_likelihood.mean(), dual_loss=dual_loss)
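# Illustrative sketch (not part of the library): the E-step weights computed
# above. Per state, the non-parametric target policy re-weights the sampled
# actions by a softmax of their Q-values at temperature eta; the M-step then
# fits the parametric policy to these weights by weighted maximum likelihood.
import torch

eta = 0.5
q_values = torch.tensor([[1.0], [2.0], [0.5]])   # [n_action_samples, state_batch].
action_log_p = torch.tensor([[-1.0], [-0.7], [-1.4]], requires_grad=True)

scaled_q = q_values / eta
normalizer = torch.logsumexp(scaled_q, dim=0)
weights = torch.exp(scaled_q - normalizer)       # softmax over the action samples.
print(weights.sum(dim=0))                        # sums to 1 for each state.

policy_loss = -(weights * action_log_p).sum(dim=0).mean()  # weighted MLE (M-step).
policy_loss.backward()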
def critic_loss(self, observation):
    """Return an empty loss.

    The actor loss already computes both the actor and the critic terms.
    """
    return Loss()
def critic_loss(self, observation) -> Loss:
    """Get the critic loss."""
    return Loss()