def get_log_p_and_ope_weight(self, state, action):
    """Get log_p of a state-action pair and the off-policy weight w.r.t. the old policy."""
    pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
    pi_o = tensor_to_distribution(self.old_policy(state), **self.policy.dist_params)
    _, log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)
    _, log_p_old = get_entropy_and_log_p(pi_o, action, self.policy.action_scale)
    ratio = torch.exp(log_p - log_p_old)
    return log_p, ratio
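
# Example (illustrative, not part of the library): the returned ratio is the
# importance-sampling weight pi(a|s) / pi_old(a|s), computed in log-space for
# numerical stability. A minimal numerical sketch with plain Normal
# distributions (all names below are hypothetical):
import torch
from torch.distributions import Normal

pi = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
pi_old = Normal(loc=torch.tensor(0.5), scale=torch.tensor(1.0))
a = torch.tensor(0.2)
ratio = torch.exp(pi.log_prob(a) - pi_old.log_prob(a))  # == pi(a) / pi_old(a)
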
def get_ope_weight(self, state, action, log_prob_action):
    """Get the off-policy weight of a given transition."""
    pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
    _, log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)
    weight = off_policy_weight(log_p, log_prob_action, full_trajectory=False)
    return weight
def step_env(environment, state, action, action_scale, pi=None, render=False):
    """Perform a single step in an environment."""
    try:
        next_state, reward, done, info = environment.step(action)
    except TypeError:
        next_state, reward, done, info = environment.step(action.item())

    if not isinstance(action, torch.Tensor):
        action = torch.tensor(action, dtype=torch.get_default_dtype())

    if pi is not None:
        try:
            with torch.no_grad():
                entropy, log_prob_action = get_entropy_and_log_p(
                    pi, action, action_scale
                )
        except RuntimeError:
            entropy, log_prob_action = 0.0, 1.0
    else:
        entropy, log_prob_action = 0.0, 1.0

    observation = Observation(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        done=done,
        entropy=entropy,
        log_prob_action=log_prob_action,
    ).to_torch()
    state = next_state

    if render:
        environment.render()

    return observation, state, done, info
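
# Usage sketch (illustrative; `environment` is assumed to follow the gym-style
# interface used above, and `agent` is a hypothetical object with an `act`
# method returning an action for a state):
state = environment.reset()
done = False
while not done:
    action = agent.act(state)
    observation, state, done, info = step_env(
        environment, state, action, action_scale=1.0, render=False
    )
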
def get_kl_entropy(self, state):
    """Get the KL divergence and the entropy of the current policy at a given state.

    Compute the separated KL divergence between the current and the old policy.
    When the policy is a MultivariateNormal distribution, it computes the
    divergence terms that correspond to the mean and the covariance separately.
    When the policy is a Categorical distribution, it computes the divergence
    and assigns it to the mean component; the variance component is set to zero.

    Parameters
    ----------
    state: torch.Tensor
        Empirical state distribution.

    Returns
    -------
    kl_mean: torch.Tensor
        KL divergence due to the change in the mean between the current and the
        previous policy.
    kl_var: torch.Tensor
        KL divergence due to the change in the variance between the current and
        the previous policy.
    entropy: torch.Tensor
        Entropy of the current policy at the given state.
    """
    pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
    pi_old = tensor_to_distribution(
        self.old_policy(state), **self.policy.dist_params
    )
    try:
        action = pi.rsample()
    except NotImplementedError:
        action = pi.sample()
    if not self.policy.discrete_action:
        action = self.policy.action_scale * action.clamp(-1.0, 1.0)

    entropy, log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)
    _, log_p_old = get_entropy_and_log_p(pi_old, action, self.policy.action_scale)

    kl_mean, kl_var = separated_kl(p=pi_old, q=pi, log_p=log_p_old, log_q=log_p)
    return kl_mean, kl_var, entropy
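
# Sketch of one common "separated KL" decomposition for Gaussians, as used in
# MPO-style trust regions (an assumption about what `separated_kl` computes;
# the actual implementation may differ): KL(p || q) = KL_mean + KL_var, where
# KL_mean only varies the mean and KL_var only varies the covariance.
import torch
from torch.distributions import MultivariateNormal, kl_divergence

p = MultivariateNormal(torch.zeros(2), scale_tril=torch.eye(2))
q = MultivariateNormal(torch.ones(2), scale_tril=0.5 * torch.eye(2))

kl_mean = kl_divergence(
    MultivariateNormal(p.loc, scale_tril=q.scale_tril), q
)  # change in the mean only.
kl_var = kl_divergence(
    p, MultivariateNormal(p.loc, scale_tril=q.scale_tril)
)  # change in the covariance only.
assert torch.allclose(kl_mean + kl_var, kl_divergence(p, q), atol=1e-5)
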
def _policy_weighted_nll(self, state, action, weights):
    """Return the weighted policy negative log-likelihood."""
    pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
    _, action_log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)
    weighted_log_p = weights.detach() * action_log_p

    # Clamping is crucial for stability so that it does not converge to a delta.
    log_likelihood = torch.mean(weighted_log_p.clamp_max(1e-3))
    return -log_likelihood
def actor_loss(self, observation):
    """Use the model to compute the gradient loss."""
    state, action = observation.state, observation.action
    next_state, done = observation.next_state, observation.done

    # Infer eta.
    action_mean, action_chol = self.policy(state)
    with torch.no_grad():
        eta = torch.inverse(action_chol) @ (action - action_mean).unsqueeze(-1)

    # Compute entropy and log_probability.
    pi = tensor_to_distribution((action_mean, action_chol))
    _, log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)

    # Compute off-policy weight.
    with torch.no_grad():
        weight = self.get_ope_weight(state, action, observation.log_prob_action)

    with DisableGradient(
        self.dynamical_model,
        self.reward_model,
        self.termination_model,
        self.critic_target,
    ):
        # Compute re-parameterized policy sample.
        action = (action_mean + (action_chol @ eta).squeeze(-1)).clamp(-1, 1)

        # Infer xi.
        ns_mean, ns_chol = self.dynamical_model(state, action)
        with torch.no_grad():
            xi = torch.inverse(ns_chol) @ (next_state - ns_mean).unsqueeze(-1)

        # Compute re-parameterized next-state sample.
        ns = ns_mean + (ns_chol @ xi).squeeze(-1)

        # Compute reward.
        r = tensor_to_distribution(self.reward_model(state, action, ns)).rsample()
        r = r[..., 0]

        next_v = self.value_function(ns)
        if isinstance(self.critic, (NNEnsembleValueFunction, NNEnsembleQFunction)):
            next_v = next_v[..., 0]

        v = r + self.gamma * next_v * (1 - done)

    return Loss(policy_loss=-(weight * v)).reduce(self.criterion.reduction)
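
# Sketch of the re-parameterization used above (illustrative): `eta` is the
# noise that would have generated the stored `action` under the current policy.
# Rebuilding action = mean + chol @ eta yields the same numerical value, but
# the result is differentiable w.r.t. the policy parameters (a pathwise
# gradient estimator). All names below are hypothetical.
import torch

mean = torch.tensor([0.1, -0.2], requires_grad=True)
chol = torch.diag(torch.tensor([0.5, 0.3]))
action = torch.tensor([0.4, 0.0])  # e.g. an action stored in a replay buffer.

with torch.no_grad():
    eta = torch.inverse(chol) @ (action - mean)

reparam_action = mean + chol @ eta  # numerically equals `action`...
assert torch.allclose(reparam_action, action)
reparam_action.sum().backward()  # ...but carries gradients back to `mean`.
assert mean.grad is not None
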
def forward(self, observation):
    """Compute the loss and the td-error."""
    state, action, reward, next_state, done, *_ = observation
    behavior_log_p = observation.log_prob_action
    n_steps = state.shape[1]

    # done_t indicates whether the current state is terminal.
    done_t = torch.cat((torch.zeros(done.shape[0], 1), done), -1)[:, :-1]

    # Compute off-policy correction factor.
    if self.policy is not None:
        pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
        _, log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)
    else:
        log_p = behavior_log_p
    correction = self.correction(log_p, behavior_log_p)

    # Compute Q(state, action) and E_pi[Q(next_state, pi(next_state))].
    if isinstance(self.critic, AbstractValueFunction):
        this_v = self.critic(state) * (1.0 - done_t)
        next_v = self.critic(next_state)
    else:
        this_v = self.critic(state, action) * (1.0 - done_t)
        if self.policy is not None:
            next_v = self.value_target(next_state)
        else:
            next_v = self.critic(next_state[:, : n_steps - 1], action[:, 1:])
            last_v = torch.zeros(next_v.shape[0], 1)
            if last_v.ndim < next_v.ndim:
                last_v = last_v.unsqueeze(-1).repeat_interleave(next_v.shape[-1], -1)
            next_v = torch.cat((next_v, last_v), -1)
    next_v = next_v * (1.0 - done)

    # Compute td = r + gamma * E_pi[Q(next_state, pi(next_state))] - Q(state, action).
    td = self.td(this_v, next_v, reward, correction)

    # Compute correction_factor_t = \prod_{i=1}^{t} c_i.
    correction_factor = torch.cumprod(correction, dim=-1)

    # Compute discount_t = gamma ** (t - 1).
    discount = torch.pow(torch.tensor(self.gamma), torch.arange(n_steps))

    # Compute target = Q(s, a) + \sum_{i=1}^{t} discount_i * factor_i * td_i. See RETRACE.
    target = this_v + reverse_cumsum(td * discount * correction_factor)

    return target
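
# The target above relies on `reverse_cumsum`. Assuming it matches its name, it
# accumulates from the last time step backwards along a dimension; a minimal
# sketch of such a helper (the library's own implementation may differ):
import torch

def reverse_cumsum_sketch(x, dim=-1):
    """Cumulative sum computed from the end of `dim` towards its beginning."""
    return x.flip(dim).cumsum(dim).flip(dim)

# reverse_cumsum_sketch(torch.tensor([1., 2., 3.])) -> tensor([6., 5., 3.])
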
def actor_loss(self, observation):
    """Get Actor loss."""
    state, action, *_ = observation
    pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
    entropy, _ = get_entropy_and_log_p(pi, action, self.policy.action_scale)

    policy_loss = integrate(
        lambda a: -pi.log_prob(a)
        * (
            self.critic(state, self.policy.action_scale * a)
            - self.value_target(state)
        ).detach(),
        pi,
        num_samples=self.num_samples,
    ).sum()
    return Loss(policy_loss=policy_loss).reduce(self.criterion.reduction)
def actor_loss(self, observation):
    """Return primal and dual loss terms from REPS."""
    state, action, reward, next_state, done, *r = observation

    # Compute scaled TD-errors.
    value = self.critic(state)
    # For the dual function we need the full gradient, not the semi-gradient!
    target = self.get_value_target(observation)

    pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)
    _, action_log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)

    reps_loss = self.reps_loss(action_log_p, value, target)
    self._info.update(reps_eta=self.reps_loss.eta)
    return reps_loss + Loss(dual_loss=(1.0 - self.gamma) * value.mean())
def step_model(
    dynamical_model,
    reward_model,
    termination_model,
    state,
    action,
    done=None,
    action_scale=1.0,
    pi=None,
):
    """Perform a single step in a dynamical model."""
    # Sample the next state.
    next_state_out = dynamical_model(state, action)
    next_state_distribution = tensor_to_distribution(next_state_out)
    if next_state_distribution.has_rsample:
        next_state = next_state_distribution.rsample()
    else:
        next_state = next_state_distribution.sample()

    # Sample a reward.
    reward_distribution = tensor_to_distribution(
        reward_model(state, action, next_state)
    )
    if reward_distribution.has_rsample:
        reward = reward_distribution.rsample().squeeze(-1)
    else:
        reward = reward_distribution.sample().squeeze(-1)
    if done is None:
        done = torch.zeros_like(reward).bool()
    reward *= (~done).float()

    # Check for termination.
    if termination_model is not None:
        done = done + (  # "+" is a boolean "or".
            tensor_to_distribution(termination_model(state, action, next_state))
            .sample()
            .bool()
        )

    if pi is not None:
        try:
            entropy, log_prob_action = get_entropy_and_log_p(pi, action, action_scale)
        except RuntimeError:
            entropy, log_prob_action = 0.0, 1.0
    else:
        entropy, log_prob_action = 0.0, 1.0

    observation = Observation(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        done=done.float(),
        entropy=entropy,
        log_prob_action=log_prob_action,
        next_state_scale_tril=next_state_out[-1],
    ).to_torch()

    # Update the state.
    next_state = torch.zeros_like(state)
    next_state[~done] = observation.next_state[~done]  # update next state.
    next_state[done] = state[done]  # don't update next state.

    return observation, next_state, done
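
# Usage sketch (illustrative): simulating a short model-based rollout by
# calling `step_model` repeatedly. `policy`, `dynamical_model`, `reward_model`,
# `termination_model`, `state`, and `horizon` are assumed to exist and follow
# the interfaces used above.
trajectory = []
done = None
for _ in range(horizon):
    pi = tensor_to_distribution(policy(state), **policy.dist_params)
    action = pi.sample()
    observation, state, done = step_model(
        dynamical_model,
        reward_model,
        termination_model,
        state=state,
        action=action,
        done=done,
        pi=pi,
    )
    trajectory.append(observation)
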