Example #1
    def forward(self, kl_mean, kl_var=None):
        """Return primal and dual loss terms from MMPO.

        Parameters
        ----------
        kl_mean : torch.Tensor
            KL divergence of the mean component of the policy update.
        kl_var : torch.Tensor, optional
            KL divergence of the covariance component of the policy update.
        """
        if self.epsilon_mean == 0.0 and not self.regularization:
            return Loss()
        if kl_var is None:
            kl_var = torch.zeros_like(kl_mean)

        kl_mean, kl_var = kl_mean.mean(), kl_var.mean()
        reg_loss = self.eta_mean * kl_mean + self.eta_var * kl_var
        if self.regularization:
            return Loss(reg_loss=reg_loss)
        else:
            if self.separated_kl:
                mean_loss = self._eta_mean() * (self.epsilon_mean -
                                                kl_mean).detach()
                var_loss = self._eta_var() * (self.epsilon_var -
                                              kl_var).detach()
                dual_loss = mean_loss + var_loss
            else:
                dual_loss = self._eta_mean() * (self.epsilon_mean -
                                                kl_mean).detach()

            return Loss(dual_loss=dual_loss, reg_loss=reg_loss)
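
A minimal torch sketch (illustrative only, not the library API) of how the detached dual term above drives the multipliers: with made-up KL budgets and divergences, the gradient on each eta has the sign of (epsilon - kl), so gradient descent raises a multiplier exactly when its constraint is violated.

import torch

# Toy values; the epsilons and KL magnitudes are assumptions for illustration.
eta_mean = torch.tensor(1.0, requires_grad=True)
eta_var = torch.tensor(1.0, requires_grad=True)
epsilon_mean, epsilon_var = 0.1, 0.01
kl_mean, kl_var = torch.tensor(0.3), torch.tensor(0.005)

dual_loss = (eta_mean * (epsilon_mean - kl_mean).detach()
             + eta_var * (epsilon_var - kl_var).detach())
dual_loss.backward()
# eta_mean.grad == -0.2: the mean-KL budget is violated, so descent increases eta_mean.
# eta_var.grad == +0.005: the var-KL budget is satisfied, so descent decreases eta_var.
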
Example #2
    def forward(self, entropy):
        """Return primal and dual loss terms from entropy loss.

        Parameters
        ----------
        entropy: torch.tensor.
        """
        if self.target_entropy == 0.0 and not self.regularization:
            return Loss()
        dual_loss = self._eta() * (entropy - self.target_entropy).detach()
        reg_loss = -self.eta * entropy
        return Loss(dual_loss=dual_loss, reg_loss=reg_loss)
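
A toy sketch of the split above (assumed values, not the library API): the dual term only trains the multiplier because the entropy gap is detached, while the regularization term only trains the policy, assuming `self.eta` returns a detached value of the multiplier.

import torch

eta = torch.tensor(0.5, requires_grad=True)      # Lagrange multiplier / temperature
entropy = torch.tensor(1.2, requires_grad=True)  # stand-in for the policy's entropy
target_entropy = 1.0

dual_loss = eta * (entropy - target_entropy).detach()  # gradient flows to eta only
reg_loss = -eta.detach() * entropy                     # gradient flows to the policy only
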
Example #3
    def critic_loss(self, observation):
        """Get critic loss.

        This is usually computed using fitted value iteration and semi-gradients.
        critic_loss = criterion(pred_q, target_q.detach()).

        Parameters
        ----------
        observation: Observation.
            Sampled observations.
            It is of shape B x N x d, where:
                - B is the batch size
                - N is the N-step return
                - d is the dimension of the attribute.

        Returns
        -------
        loss: Loss.
            Loss with parameters loss, critic_loss, and td_error filled.
        """
        if self.critic is None:
            return Loss()

        pred_q = self.get_value_prediction(observation)

        # Get target_q with semi-gradients.
        with torch.no_grad():
            target_q = self.get_value_target(observation)
            if pred_q.shape != target_q.shape:  # Reshape in case of ensembles.
                assert isinstance(self.critic, NNEnsembleQFunction)
                target_q = target_q.unsqueeze(-1).repeat_interleave(
                    self.critic.num_heads, -1
                )

            td_error = pred_q - target_q  # no gradients for td-error.
            if self.criterion.reduction == "mean":
                td_error = torch.mean(td_error)
            elif self.criterion.reduction == "sum":
                td_error = torch.sum(td_error)

        critic_loss = self.criterion(pred_q, target_q)

        if isinstance(self.critic, NNEnsembleQFunction):
            # Ensembles have last dimension as ensemble head; sum all ensembles.
            critic_loss = critic_loss.sum(-1)
            td_error = td_error.sum(-1)

        # Take mean over time coordinate.
        critic_loss = critic_loss.mean(-1)
        td_error = td_error.mean(-1)

        return Loss(critic_loss=critic_loss, td_error=td_error)
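
A shape-only sketch of the B x N layout described in the docstring (toy numbers; a squared-error criterion without reduction is assumed): after the element-wise loss, the final mean is taken over the time coordinate, leaving one value per trajectory in the batch.

import torch

B, N = 2, 3                       # batch of 2 trajectories, 3-step returns
pred_q = torch.randn(B, N)
target_q = torch.randn(B, N)

critic_loss = (pred_q - target_q.detach()).pow(2)   # element-wise criterion
critic_loss = critic_loss.mean(-1)                  # mean over time -> shape (B,)
td_error = (pred_q - target_q).detach().mean(-1)    # semi-gradient TD error, shape (B,)
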
Example #4
    def forward(self, inequality_value):
        """Return primal and dual loss terms from entropy loss.

        Parameters
        ----------
        inequality_value: torch.tensor.
        """
        if self.inequality_zero == 0.0 and not self.regularization:
            return Loss()
        dual_loss = self._dual() * (self.inequality_zero -
                                    inequality_value).detach()
        reg_loss = self._dual().detach() * inequality_value
        return Loss(dual_loss=dual_loss, reg_loss=reg_loss)
Example #5
    def model_augmented_critic_loss(self, observation):
        """Get Model-Based critic-loss."""
        with torch.no_grad():
            state = observation.state[..., 0, :]
            action = observation.action[..., 0, :]
            sim_observation = self.simulate(
                state, self.policy, initial_action=action, stack_obs=True)

        if not self.td_k:
            sim_observation.state = observation.state[..., :1, :]
            sim_observation.action = observation.action[..., :1, :]

        pred_q = self.base_algorithm.get_value_prediction(sim_observation)

        # Get target_q with semi-gradients.
        with torch.no_grad():
            target_q = self.get_value_target(sim_observation)
            if not self.td_k:
                target_q = target_q.reshape(self.num_samples,
                                            *pred_q.shape[:2]).mean(0)
            if pred_q.shape != target_q.shape:  # Reshape in case of ensembles.
                assert isinstance(self.critic, NNEnsembleQFunction)
                target_q = target_q.unsqueeze(-1).repeat_interleave(
                    self.critic.num_heads, -1)

        critic_loss = self.base_algorithm.criterion(pred_q, target_q)

        return Loss(critic_loss=critic_loss)
Example #6
    def forward(self, observation, idx=None):
        """Compute losses at state/idx pairs."""
        state = observation.state
        if idx is None:
            idx = torch.arange(state.shape[0])
        return Loss(dual_loss=self.dual(observation, idx=idx)
                    + self.get_discount_dual_loss(state, idx))
Example #7
    def forward(self, observation):
        """Compute the losses.

        Given an Observation, it computes the losses directly.
        Given a list of trajectories, it tries to stack them to vectorize the operations;
        if stacking fails, it iterates over the trajectories.
        """
        if isinstance(observation, Observation):
            trajectories = [observation]
        elif len(observation) > 1:
            try:
                # When possible, stack to parallelize the trajectories.
                # This requires all trajectories to be of equal length.
                trajectories = [stack_list_of_tuples(observation)]
            except RuntimeError:
                trajectories = observation
        else:
            trajectories = observation

        self.reset_info()

        loss = Loss()
        for trajectory in trajectories:
            loss += self.actor_loss(trajectory)
            loss += self.critic_loss(trajectory)
            loss += self.regularization_loss(trajectory, len(trajectories))

        return loss / len(trajectories)
Example #8
File: reps.py Project: sebascuri/rllib
    def forward(self, action_log_p, value, target):
        """Return primal and dual loss terms from REPS.

        Parameters
        ----------
        action_log_p : torch.Tensor
            A [state_batch, 1] tensor of log probabilities of the corresponding actions
            under the policy.
        value: torch.Tensor
            The value function V(s), evaluated with gradients at the sampled states.
        target: torch.Tensor
            The value target r + gamma * V(s'), evaluated with gradients.
        """
        td = target - value
        weights = td / self._eta()
        normalizer = torch.logsumexp(weights, dim=0)
        dual_loss = self._eta() * (self.epsilon + normalizer)

        # Clamping is crucial for stability so that it does not converge to a delta.
        weighted_log_p = torch.exp(weights).clamp_max(1e2).detach() * action_log_p
        log_likelihood = weighted_log_p.mean()

        return Loss(policy_loss=-log_likelihood,
                    dual_loss=dual_loss,
                    td_error=td.mean())
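
A standalone sketch of the REPS weighting above with toy numbers (not the library API): the TD errors are turned into exponential sample weights through a log-sum-exp normalizer, and the dual combines the temperature eta with the KL budget epsilon.

import torch

td = torch.tensor([0.2, -0.1, 0.4, 0.0])   # toy TD errors for a batch of 4 states
eta, epsilon = torch.tensor(0.5), 0.1      # assumed temperature and KL budget

weights = td / eta
normalizer = torch.logsumexp(weights, dim=0)
dual_loss = eta * (epsilon + normalizer)
sample_weights = torch.exp(weights).clamp_max(1e2)  # weights for the policy fit
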
Example #9
    def forward(self, observation):
        """Rollout model and call base algorithm with transitions."""
        self.base_algorithm.reset_info()
        loss = Loss()
        loss += self.base_algorithm.actor_loss(observation)
        loss += self.model_augmented_critic_loss(observation)
        loss += self.base_algorithm.regularization_loss(observation)
        return loss
Example #10
    def forward(self, observation):
        """Compute path-wise loss."""
        if self.policy is None or self.critic is None:
            return Loss()
        state = observation.state
        pi = tensor_to_distribution(self.policy(state),
                                    **self.policy.dist_params)
        action = self.policy.action_scale * pi.rsample().clamp(-1, 1)

        with DisableGradient(self.critic):
            q = self.critic(state, action)
            if isinstance(self.critic, NNEnsembleQFunction):
                q = q[..., 0]

        # Take mean over the time coordinate (only when a time dimension exists).
        if q.dim() > 1:
            q = q.mean(dim=1)

        return Loss(policy_loss=-q)
Example #11
File: svg.py Project: sebimarkgraf/rllib
    def actor_loss(self, observation):
        """Use the model to compute the gradient loss."""
        state, action = observation.state, observation.action
        next_state, done = observation.next_state, observation.done

        # Infer eta.
        action_mean, action_chol = self.policy(state)
        with torch.no_grad():
            eta = torch.inverse(action_chol) @ (
                (action - action_mean).unsqueeze(-1))

        # Compute entropy and log_probability.
        pi = tensor_to_distribution((action_mean, action_chol))
        _, log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)

        # Compute off-policy weight.
        with torch.no_grad():
            weight = self.get_ope_weight(state, action,
                                         observation.log_prob_action)

        with DisableGradient(
                self.dynamical_model,
                self.reward_model,
                self.termination_model,
                self.critic_target,
        ):
            # Compute re-parameterized policy sample.
            action = (action_mean + (action_chol @ eta).squeeze(-1)).clamp(
                -1, 1)

            # Infer xi.
            ns_mean, ns_chol = self.dynamical_model(state, action)
            with torch.no_grad():
                xi = torch.inverse(ns_chol) @ (
                    (next_state - ns_mean).unsqueeze(-1))

            # Compute re-parameterized next-state sample.
            ns = ns_mean + (ns_chol @ xi).squeeze(-1)

            # Compute reward.
            r = tensor_to_distribution(self.reward_model(state, action,
                                                         ns)).rsample()
            r = r[..., 0]

            next_v = self.value_function(ns)
            if isinstance(self.critic,
                          (NNEnsembleValueFunction, NNEnsembleQFunction)):
                next_v = next_v[..., 0]

            v = r + self.gamma * next_v * (1 - done)

        return Loss(policy_loss=-(weight * v)).reduce(self.criterion.reduction)
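The eta/xi inversions above implement the SVG-style re-parameterization: an observed sample is written as mean + chol @ noise, so inverting the Cholesky factor recovers the noise and lets gradients flow through mean and chol while the sample stays fixed. A toy check with made-up 2-D values:

import torch

mean = torch.tensor([[0.1, -0.2]])                # batch of 1, 2-D action
chol = torch.tensor([[[0.5, 0.0], [0.1, 0.4]]])   # lower-triangular scale
sample = torch.tensor([[0.3, 0.0]])               # observed sample

noise = torch.inverse(chol) @ (sample - mean).unsqueeze(-1)  # infer the noise ("eta")
rebuilt = mean + (chol @ noise).squeeze(-1)                  # re-parameterized sample
assert torch.allclose(rebuilt, sample)
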
Example #12
File: steve.py Project: sebimarkgraf/rllib
    def model_augmented_critic_loss(self, observation):
        """Get Model-Based critic-loss."""
        pred_q = self.base_algorithm.get_value_prediction(observation)

        # Get target_q with semi-gradients.
        with torch.no_grad():
            target_q = self.get_value_target(observation)
            if pred_q.shape != target_q.shape:  # Reshape in case of ensembles.
                assert isinstance(self.critic, NNEnsembleQFunction)
                target_q = target_q.unsqueeze(-1).repeat_interleave(
                    self.critic.num_heads, -1)

        critic_loss = self.base_algorithm.criterion(pred_q, target_q)

        return Loss(critic_loss=critic_loss)
Example #13
    def score_actor_loss(self, observation, linearized=False):
        """Get score actor loss for policy gradients."""
        state, action, reward, next_state, done, *r = observation

        log_p, ratio = self.get_log_p_and_ope_weight(state, action)

        with torch.no_grad():
            adv = self.returns(observation)
            if self.standardize_returns:
                adv = (adv - adv.mean()) / (adv.std() + self.eps)

        if linearized:
            score = ratio * adv
        else:
            score = discount_sum(log_p * adv, self.gamma)

        return Loss(policy_loss=-score)
Example #14
    def actor_loss(self, observation):
        """Get Actor loss."""
        state, action, *_ = observation

        pi = tensor_to_distribution(self.policy(state),
                                    **self.policy.dist_params)
        entropy, _ = get_entropy_and_log_p(pi, action,
                                           self.policy.action_scale)

        policy_loss = integrate(
            lambda a: -pi.log_prob(a) * (
                self.critic(state, self.policy.action_scale * a)
                - self.value_target(state)
            ).detach(),
            pi,
            num_samples=self.num_samples,
        ).sum()

        return Loss(policy_loss=policy_loss).reduce(self.criterion.reduction)
Example #15
File: reps.py Project: sebascuri/rllib
    def actor_loss(self, observation):
        """Return primal and dual loss terms from REPS."""
        state, action, reward, next_state, done, *r = observation

        # Compute Scaled TD-Errors
        value = self.critic(state)

        # For dual function we need the full gradient, not the semi gradient!
        target = self.get_value_target(observation)

        pi = tensor_to_distribution(self.policy(state),
                                    **self.policy.dist_params)
        _, action_log_p = get_entropy_and_log_p(pi, action,
                                                self.policy.action_scale)

        reps_loss = self.reps_loss(action_log_p, value, target)
        self._info.update(reps_eta=self.reps_loss.eta)
        return reps_loss + Loss(dual_loss=(1.0 - self.gamma) * value.mean())
Example #16
    def actor_loss(self, trajectory):
        """Get actor loss."""
        state, action, reward, next_state, done, *r = trajectory
        log_p, ratio = self.get_log_p_and_ope_weight(state, action)

        with torch.no_grad():
            adv = self.returns(trajectory)
            if self.standardize_returns:
                adv = (adv - adv.mean()) / (adv.std() + self.eps)

        # Compute surrogate loss.
        weighted_advantage = ratio * adv
        clipped_advantage = ratio.clamp(1 - self.epsilon(),
                                        1 + self.epsilon()) * adv
        surrogate_loss = -torch.min(weighted_advantage, clipped_advantage)
        # Unlike TRPO, which enforces a trust-region constraint, PPO clips the
        # ratio and takes the element-wise minimum above.

        return Loss(policy_loss=surrogate_loss).reduce(
            self.criterion.reduction)
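
A numerical illustration of the clipped surrogate above (toy ratios and advantages; epsilon() is assumed to return 0.2): clipping removes the incentive to push the ratio far from 1 in the direction favored by the advantage.

import torch

ratio = torch.tensor([0.5, 1.0, 1.5])
adv = torch.tensor([1.0, -1.0, 1.0])
eps = 0.2

clipped = ratio.clamp(1 - eps, 1 + eps) * adv
surrogate_loss = -torch.min(ratio * adv, clipped)
# tensor([-0.5000,  1.0000, -1.2000]): the third entry is capped at 1.2 * adv.
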
Example #17
File: reps.py Project: sebimarkgraf/rllib
    def actor_loss(self, observation):
        """Return primal and dual loss terms from REPS."""
        state, action, reward, next_state, done, *r = observation

        # Compute Scaled TD-Errors
        value = self.critic(state)

        # For dual function we need the full gradient, not the semi gradient!
        target = self.get_value_target(observation)
        td = target - value

        weights = td / self.eta()
        normalizer = torch.logsumexp(weights, dim=0)
        dual = (self.eta() * (self.epsilon + normalizer)
                + (1.0 - self.gamma) * value)

        nll = self._policy_weighted_nll(state, action, weights)

        return Loss(dual_loss=dual.mean(), policy_loss=nll, td_error=td)
Example #18
    def actor_loss(self, observation):
        """Return primal and dual loss terms from Q-REPS."""
        state, action, reward, next_state, done, *r = observation

        # Calculate dual variables
        value = self.critic(state)
        target = self.get_value_target(observation)
        q_value = self.q_function(state, action)

        td = target - q_value
        self._info.update(td=td)

        # Calculate weights.
        weights_td = self.eta() * td  # type: torch.Tensor
        if weights_td.ndim == 1:
            weights_td = weights_td.unsqueeze(-1)
        dual = 1 / self.eta() * torch.logsumexp(weights_td, dim=-1)
        dual += (1 - self.gamma) * value.squeeze(-1)
        return Loss(dual_loss=dual.mean(), td_error=td)
Example #19
    def actor_loss(self, observation):
        """Get actor loss.

        This is different for each algorithm.

        Parameters
        ----------
        observation: Observation.
            Sampled observations.
            It is of shape B x N x d, where:
                - B is the batch size
                - N is the N-step return
                - d is the dimension of the attribute.

        Returns
        -------
        loss: Loss.
            Loss with parameters loss, policy_loss, and regularization_loss filled.
        """
        return Loss()
Example #20
    def actor_loss(self, observation) -> Loss:
        """Compute Actor loss."""
        state = observation.state[..., 0, :]
        action = observation.action[..., 0, :]
        action_mean, action_chol = self.policy(state)

        # Infer eta.
        with torch.no_grad():
            delta = action / self.policy.action_scale - action_mean
            eta = torch.inverse(action_chol) @ delta.unsqueeze(-1)

        # Compute re-parameterized policy sample.
        action = self.policy.action_scale * (
            action_mean + (action_chol @ eta).squeeze(-1)).clamp(-1.0, 1.0)

        # Propagate gradient.
        with DisableGradient(self.critic):
            q = self.critic(observation.state[..., 0, :], action)
            if isinstance(self.critic, NNEnsembleQFunction):
                q = q[..., 0]

        return Loss(policy_loss=-q).reduce(self.criterion.reduction)
Example #21
File: mbmpo.py Project: sebascuri/hucrl
    def actor_loss(self, observation):
        """Compute the losses for one step of MPO.

        Parameters
        ----------
        observation : Observation
            The states at which to compute the losses.
        """
        state = observation.state
        value_prediction = self.critic(state)

        with torch.no_grad():
            value_estimate, obs = mb_return(
                state=state,
                dynamical_model=self.dynamical_model,
                policy=self.old_policy,
                reward_model=self.reward_model,
                num_steps=1,
                gamma=self.gamma,
                value_function=self.critic_target,
                num_samples=self.num_samples,
                reward_transformer=self.reward_transformer,
                termination_model=self.termination_model,
                reduction="min",
            )
        q_values = value_estimate
        log_p, _ = self.get_log_p_and_ope_weight(obs.state, obs.action)

        # Since actions come from policy, value is the expected q-value
        mpo_loss = self.mpo_loss(q_values=q_values,
                                 action_log_p=log_p.squeeze(-1))
        value_loss = self.criterion(value_prediction, q_values.mean(dim=0))
        td_error = value_prediction - q_values.mean(dim=0)

        critic_loss = Loss(critic_loss=value_loss, td_error=td_error)
        self._info.update(eta=self.mpo_loss.eta)
        return mpo_loss.reduce(self.criterion.reduction) + critic_loss
Example #22
File: mpo.py Project: sebimarkgraf/rllib
    def forward(self, q_values, action_log_p):
        """Return primal and dual loss terms from MPO.

        Parameters
        ----------
        q_values : torch.Tensor
            A [n_action_samples, state_batch, 1] tensor of values for
            state-action pairs.
        action_log_p : torch.Tensor
            A [n_action_samples, state_batch, 1] tensor of log probabilities
            of the corresponding actions under the policy.
        """
        # Make sure the Lagrange multipliers stay positive.
        # self.project_etas()

        # E-step: Solve Problem (7).
        # Create a weighted, sample-based representation of the optimal policy q, Eq. (8).
        # Compute the dual loss for the constraint KL(q || old_pi) < eps.
        q_values = q_values.detach() * (torch.tensor(1.0) / self._eta())
        normalizer = torch.logsumexp(q_values, dim=0)
        num_actions = torch.tensor(1.0 * action_log_p.shape[0])

        dual_loss = self._eta() * (
            self.epsilon + torch.mean(normalizer) - torch.log(num_actions)
        )
        # non-parametric representation of the optimal policy.
        weights = torch.exp(q_values - normalizer.detach())

        # M-step: Solve Problem (10).
        # Fit the parametric policy to the representation from the E-step.
        # Maximize the log-likelihood of the weighted log probabilities, subject to the
        # KL divergence between old_pi and new_pi being smaller than epsilon.

        weighted_log_p = torch.sum(weights * action_log_p, dim=0)
        log_likelihood = weighted_log_p

        return Loss(policy_loss=-log_likelihood.mean(), dual_loss=dual_loss)
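
A standalone sketch of the E-step weighting above (made-up shapes: 5 action samples, batch of 3 states; eta assumed to be 1.0): the soft-max over the sample axis yields the non-parametric weights, which sum to one per state.

import torch

q_values = torch.randn(5, 3, 1)            # [n_action_samples, state_batch, 1]
eta = torch.tensor(1.0)

scaled = q_values / eta
normalizer = torch.logsumexp(scaled, dim=0)
weights = torch.exp(scaled - normalizer)   # soft-max over the action samples
assert torch.allclose(weights.sum(dim=0), torch.ones(3, 1))
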
Example #23
File: mbmpo.py Project: sebascuri/hucrl
    def critic_loss(self, observation):
        """Return a zero loss; the actor loss already includes the critic loss."""
        return Loss()
Example #24
File: reps.py Project: sebascuri/rllib
    def critic_loss(self, observation) -> Loss:
        """Get the critic loss."""
        return Loss()