Example #1
0
 def get_log_p_and_ope_weight(self, state, action):
     """Return the current log-probability of ``action`` and its importance ratio.

     The ratio is ``pi(action | state) / pi_old(action | state)``, evaluated in
     log-space as ``exp(log_p - log_p_old)``.

     Parameters
     ----------
     state: torch.Tensor
         States at which both policies are evaluated.
     action: torch.Tensor
         Actions to score.

     Returns
     -------
     log_p: torch.Tensor
         Log-probability of ``action`` under the current policy.
     ratio: torch.Tensor
         Probability ratio between the current and the old policy.
     """
     dist_kwargs = self.policy.dist_params
     current_pi = tensor_to_distribution(self.policy(state), **dist_kwargs)
     # NOTE: the old policy's output is interpreted with the *current* policy's
     # dist_params.
     previous_pi = tensor_to_distribution(self.old_policy(state), **dist_kwargs)
     current_log_p, previous_log_p = (
         get_entropy_and_log_p(dist, action, self.policy.action_scale)[1]
         for dist in (current_pi, previous_pi)
     )
     return current_log_p, (current_log_p - previous_log_p).exp()
Example #2
0
    def get_ope_weight(self, state, action, log_prob_action):
        """Compute the off-policy weight of a transition under the current policy.

        Parameters
        ----------
        state: torch.Tensor
            State at which ``action`` was taken.
        action: torch.Tensor
            Action that was taken.
        log_prob_action: torch.Tensor
            Log-probability of ``action`` as stored with the transition
            (e.g. ``observation.log_prob_action``).

        Returns
        -------
        weight: torch.Tensor
            Result of ``off_policy_weight`` with ``full_trajectory=False``.
        """
        policy = self.policy
        distribution = tensor_to_distribution(policy(state), **policy.dist_params)
        current_log_p = get_entropy_and_log_p(
            distribution, action, policy.action_scale
        )[1]
        return off_policy_weight(current_log_p, log_prob_action, full_trajectory=False)
Example #3
0
def step_env(environment, state, action, action_scale, pi=None, render=False):
    """Advance ``environment`` by one step and package the transition.

    Parameters
    ----------
    environment:
        Object exposing ``step`` (and ``render`` when ``render`` is True).
    state:
        State from which ``action`` is executed.
    action:
        Action to execute. If ``environment.step(action)`` raises TypeError,
        the step is retried with ``action.item()``. Non-tensor actions are
        converted to a tensor of the default dtype before being stored.
    action_scale:
        Forwarded to ``get_entropy_and_log_p``.
    pi: optional
        Distribution used to score ``action``. When it is None, or scoring
        raises RuntimeError, entropy 0.0 and log-probability 1.0 are stored.
    render: bool
        Whether to call ``environment.render()`` after stepping.

    Returns
    -------
    observation: Observation
        The transition, converted with ``to_torch()``.
    next_state:
        State returned by the environment.
    done:
        Termination flag returned by the environment.
    info:
        Info returned by the environment.
    """
    try:
        next_state, reward, done, info = environment.step(action)
    except TypeError:
        # Presumably for environments that only accept Python scalars.
        next_state, reward, done, info = environment.step(action.item())

    if not isinstance(action, torch.Tensor):
        action = torch.tensor(action, dtype=torch.get_default_dtype())

    # Neutral defaults, kept when there is no policy or it cannot score the action.
    entropy, log_prob_action = 0.0, 1.0
    if pi is not None:
        try:
            with torch.no_grad():
                entropy, log_prob_action = get_entropy_and_log_p(
                    pi, action, action_scale
                )
        except RuntimeError:
            # Best effort: keep the neutral defaults set above.
            pass

    observation = Observation(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        done=done,
        entropy=entropy,
        log_prob_action=log_prob_action,
    ).to_torch()

    if render:
        environment.render()
    return observation, next_state, done, info
Example #4
0
    def get_kl_entropy(self, state):
        """Return the separated KL divergence between old and current policy, and the entropy.

        Both policies are evaluated at ``state``. One action is drawn from the
        current policy, re-parameterized when the distribution supports
        ``rsample``. For continuous actions it is clamped to [-1, 1] and
        multiplied by ``action_scale``. Its log-probabilities under both
        policies are passed to ``separated_kl``.

        For a MultivariateNormal policy, the divergence is split into the part
        due to the mean and the part due to the covariance. For a Categorical
        policy, the whole divergence is assigned to the mean component and the
        variance component is zero.

        Parameters
        ----------
        state: torch.Tensor
            Empirical state distribution.

        Returns
        -------
        kl_mean: torch.Tensor
            KL-Divergence due to the change in the mean between current and
            previous policy.
        kl_var: torch.Tensor
            KL-Divergence due to the change in the variance between current and
            previous policy.
        entropy: torch.Tensor
            Entropy of the current policy at the given state.
        """
        dist_params = self.policy.dist_params
        pi = tensor_to_distribution(self.policy(state), **dist_params)
        pi_old = tensor_to_distribution(self.old_policy(state), **dist_params)

        # Prefer a re-parameterized sample; fall back when it is not implemented.
        try:
            sample = pi.rsample()
        except NotImplementedError:
            sample = pi.sample()

        scale = self.policy.action_scale
        if self.policy.discrete_action:
            action = sample
        else:
            action = scale * (sample.clamp(-1.0, 1.0))

        entropy, log_p = get_entropy_and_log_p(pi, action, scale)
        log_p_old = get_entropy_and_log_p(pi_old, action, scale)[1]

        kl_mean, kl_var = separated_kl(p=pi_old, q=pi, log_p=log_p_old, log_q=log_p)
        return kl_mean, kl_var, entropy
Example #5
0
    def _policy_weighted_nll(self, state, action, weights):
        """Return the negative mean of the weighted, clamped action log-likelihoods.

        ``weights`` are detached, so no gradient flows into them. Each weighted
        log-probability is clamped from above at 1e-3 before averaging.
        """
        policy = self.policy
        distribution = tensor_to_distribution(policy(state), **policy.dist_params)
        log_p = get_entropy_and_log_p(distribution, action, policy.action_scale)[1]

        # Clamping is crucial for stability so that it does not converge to a delta.
        clamped = (log_p * weights.detach()).clamp_max(1e-3)
        return -clamped.mean()
Example #6
0
    def actor_loss(self, observation):
        """Compute a model-based policy loss through re-parameterized samples.

        The observed action and next state are treated as re-parameterized
        samples. The policy noise ``eta`` and the dynamics noise ``xi`` are
        inferred from the data without gradient. The action and next state are
        then regenerated from the current policy output and the dynamical model
        using that fixed noise. The one-step value
        ``r + gamma * next_v * (1 - done)`` is weighted by the off-policy weight
        of the observed action.

        Parameters
        ----------
        observation: Observation
            Batch of transitions. ``state``, ``action``, ``next_state``,
            ``done`` and ``log_prob_action`` are read.

        Returns
        -------
        loss: Loss
            ``Loss(policy_loss=-(weight * v))`` reduced with
            ``self.criterion.reduction``.
        """
        state, action = observation.state, observation.action
        next_state, done = observation.next_state, observation.done

        # Infer eta such that action = action_mean + action_chol @ eta.
        action_mean, action_chol = self.policy(state)
        with torch.no_grad():
            eta = torch.inverse(action_chol) @ ((action - action_mean).unsqueeze(-1))

        # Off-policy weight of the observed action (no gradient).
        with torch.no_grad():
            weight = self.get_ope_weight(state, action, observation.log_prob_action)

        with DisableGradient(
            self.dynamical_model,
            self.reward_model,
            self.termination_model,
            self.critic_target,
        ):
            # Re-parameterized policy sample using the inferred (fixed) noise.
            action = (action_mean + (action_chol @ eta).squeeze(-1)).clamp(-1, 1)

            # Infer xi such that next_state = ns_mean + ns_chol @ xi.
            ns_mean, ns_chol = self.dynamical_model(state, action)
            with torch.no_grad():
                xi = torch.inverse(ns_chol) @ ((next_state - ns_mean).unsqueeze(-1))

            # Re-parameterized next-state sample.
            ns = ns_mean + (ns_chol @ xi).squeeze(-1)

            # Re-parameterized reward sample; keep the first output component.
            r = tensor_to_distribution(self.reward_model(state, action, ns)).rsample()
            r = r[..., 0]

            next_v = self.value_function(ns)
            # NOTE(review): the ensemble check inspects self.critic while the value
            # comes from self.value_function; confirm they share the output layout.
            if isinstance(self.critic, (NNEnsembleValueFunction, NNEnsembleQFunction)):
                next_v = next_v[..., 0]

            v = r + self.gamma * next_v * (1 - done)

        return Loss(policy_loss=-(weight * v)).reduce(self.criterion.reduction)
Example #7
0
    def forward(self, observation):
        """Compute a multi-step, off-policy-corrected target (retrace-style).

        Steps:
        1. Per-step correction ``c_t = self.correction(log_p, behavior_log_p)``.
           Here ``log_p`` comes from ``self.policy`` when it is set; otherwise
           the behavior log-probabilities are used for both arguments.
        2. Current values ``this_v`` (zeroed where the trajectory already ended)
           and next values ``next_v`` (zeroed where ``done``).
        3. ``td = self.td(this_v, next_v, reward, correction)``.
        4. ``target = this_v + reverse_cumsum(td * gamma**t * cumprod(c))``.

        Parameters
        ----------
        observation: Observation
            Trajectory batch. It is unpacked as ``state, action, reward,
            next_state, done, ...`` and ``log_prob_action`` is read. ``done`` is
            2-D (it is concatenated with a ``(batch, 1)`` tensor along the last
            axis). The overall layout is presumably (batch, n_steps, ...) —
            TODO confirm.

        Returns
        -------
        target: torch.Tensor
            The corrected target for the critic; no loss or td-error is returned.
        """
        state, action, reward, next_state, done, *_ = observation
        behavior_log_p = observation.log_prob_action
        n_steps = state.shape[1]

        # done_t is done shifted one step right (zero-padded at the front):
        # done_t[:, t] == done[:, t - 1], i.e. the trajectory ended before step t.
        done_t = torch.cat((torch.zeros(done.shape[0], 1), done), -1)[:, :-1]

        # Compute off-policy correction factor.
        if self.policy is not None:
            pi = tensor_to_distribution(self.policy(state),
                                        **self.policy.dist_params)
            _, log_p = get_entropy_and_log_p(pi, action,
                                             self.policy.action_scale)
        else:
            # No target policy: score with the behavior log-probabilities themselves.
            log_p = behavior_log_p
        correction = self.correction(log_p, behavior_log_p)

        # Compute Q(state, action) and E_pi[Q(next_state, pi(next_state))].
        if isinstance(self.critic, AbstractValueFunction):
            this_v = self.critic(state) * (1.0 - done_t)
            next_v = self.critic(next_state)
        else:
            this_v = self.critic(state, action) * (1.0 - done_t)

            if self.policy is not None:
                next_v = self.value_target(next_state)
            else:
                # Without a policy, bootstrap with Q(s_{t+1}, a_{t+1}) from the
                # trajectory itself. The last step has no successor action, so
                # its value is padded with zeros (expanded to match any extra
                # trailing dimension of next_v).
                next_v = self.critic(next_state[:, :n_steps - 1], action[:,
                                                                         1:])
                last_v = torch.zeros(next_v.shape[0], 1)
                if last_v.ndim < next_v.ndim:
                    last_v = last_v.unsqueeze(-1).repeat_interleave(
                        next_v.shape[-1], -1)
                next_v = torch.cat((next_v, last_v), -1)
        next_v = next_v * (1.0 - done)
        # Compute td = r + gamma * E_pi[Q(next_state, pi(next_state))] - Q(state, action).
        td = self.td(this_v, next_v, reward, correction)

        # Compute correction factor_t = prod_{i=1..t} c_i.
        correction_factor = torch.cumprod(correction, dim=-1)

        # Compute discount_t = gamma ** (t - 1) for t = 1..n_steps.
        discount = torch.pow(torch.tensor(self.gamma), torch.arange(n_steps))

        # Compute target = Q(s, a) + (reverse) cumulative sum of discount * factor * td.
        # reverse_cumsum presumably sums from each step to the trajectory end. See RETRACE.
        target = this_v + reverse_cumsum(td * discount * correction_factor)

        return target
Example #8
0
    def actor_loss(self, observation):
        """Compute a score-function (likelihood-ratio) policy loss.

        The integrand is
        ``-log pi(a|s) * (critic(s, action_scale * a) - value_target(s))``.
        The bracketed advantage term is detached, so gradients reach the policy
        only through ``log pi``. The integrand is integrated against ``pi`` with
        ``integrate(..., num_samples=self.num_samples)``, and the result is
        summed.

        Parameters
        ----------
        observation: Observation
            Batch of transitions. Only the leading ``state`` entry is used.

        Returns
        -------
        loss: Loss
            ``Loss(policy_loss=...)`` reduced with ``self.criterion.reduction``.
        """
        state, *_ = observation

        pi = tensor_to_distribution(self.policy(state), **self.policy.dist_params)

        def _score_function_term(a):
            # Advantage is detached: gradients flow only through log pi(a|s).
            advantage = self.critic(
                state, self.policy.action_scale * a
            ) - self.value_target(state)
            return -pi.log_prob(a) * advantage.detach()

        policy_loss = integrate(
            _score_function_term,
            pi,
            num_samples=self.num_samples,
        ).sum()

        return Loss(policy_loss=policy_loss).reduce(self.criterion.reduction)
Example #9
0
    def actor_loss(self, observation):
        """Return the REPS primal/dual loss plus the scaled mean-value dual term.

        The REPS loss is built from the current policy's log-probability of the
        observed actions, the critic value at the observed states, and the value
        target. Its current ``eta`` is recorded in ``self._info``. The added
        ``dual_loss`` is ``(1 - gamma) * mean(value)``.
        """
        state, action, _, _, _, *_ = observation

        # Scaled TD-errors need the critic value at the visited states.
        value = self.critic(state)
        # For the dual function the full gradient is needed, not the semi-gradient.
        target = self.get_value_target(observation)

        policy = self.policy
        action_log_p = get_entropy_and_log_p(
            tensor_to_distribution(policy(state), **policy.dist_params),
            action,
            policy.action_scale,
        )[1]

        loss = self.reps_loss(action_log_p, value, target)
        self._info.update(reps_eta=self.reps_loss.eta)
        dual_term = Loss(dual_loss=(1.0 - self.gamma) * value.mean())
        return loss + dual_term
Example #10
0
def step_model(
    dynamical_model,
    reward_model,
    termination_model,
    state,
    action,
    done=None,
    action_scale=1.0,
    pi=None,
):
    """Perform a single step in a learned dynamical model.

    Parameters
    ----------
    dynamical_model: callable
        Called as ``dynamical_model(state, action)``. Its output is turned into
        the next-state distribution, and its last element is stored as
        ``next_state_scale_tril``.
    reward_model: callable
        Called as ``reward_model(state, action, next_state)``. Its output is
        turned into the reward distribution.
    termination_model: callable or None
        Called as ``termination_model(state, action, next_state)``. A sample
        from it is OR-ed into ``done``. When None, ``done`` is not updated.
    state: torch.Tensor
    action: torch.Tensor
    done: torch.Tensor, optional
        Boolean flags of rollouts that already finished (default: all False).
        Their rewards are zeroed.
    action_scale: float
        Forwarded to ``get_entropy_and_log_p``.
    pi: optional
        Distribution used to score ``action``. When it is None, or scoring
        raises RuntimeError, entropy 0.0 and log-probability 1.0 are stored.

    Returns
    -------
    observation: Observation
        The simulated transition, converted with ``to_torch()``.
    next_state: torch.Tensor
        The sampled next state where ``done`` is False. Where ``done`` is True
        (including rollouts that terminate at this step), the current state is
        kept.
    done: torch.Tensor
        Updated termination flags.
    """
    # Sample a next state (re-parameterized when supported).
    next_state_out = dynamical_model(state, action)
    next_state_distribution = tensor_to_distribution(next_state_out)
    if next_state_distribution.has_rsample:
        next_state = next_state_distribution.rsample()
    else:
        next_state = next_state_distribution.sample()

    # Sample a reward (re-parameterized when supported).
    reward_distribution = tensor_to_distribution(
        reward_model(state, action, next_state)
    )
    if reward_distribution.has_rsample:
        reward = reward_distribution.rsample().squeeze(-1)
    else:
        reward = reward_distribution.sample().squeeze(-1)
    if done is None:
        done = torch.zeros_like(reward).bool()
    # Mask out of place: ``squeeze(-1)`` returns a view of the sample, which may
    # alias a model output that autograd still needs. An in-place multiply could
    # then fail at backward time, and it also breaks for non-float rewards.
    reward = reward * (~done).to(reward.dtype)

    # Check for termination: OR newly sampled terminations into done.
    if termination_model is not None:
        terminated = (
            tensor_to_distribution(termination_model(state, action, next_state))
            .sample()
            .bool()
        )
        done = done | terminated

    if pi is not None:
        try:
            entropy, log_prob_action = get_entropy_and_log_p(pi, action, action_scale)
        except RuntimeError:
            # Best effort: neutral values when pi cannot score the action.
            entropy, log_prob_action = 0.0, 1.0
    else:
        entropy, log_prob_action = 0.0, 1.0

    observation = Observation(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        done=done.float(),
        entropy=entropy,
        log_prob_action=log_prob_action,
        next_state_scale_tril=next_state_out[-1],
    ).to_torch()

    # Advance only the rollouts that are still running; finished ones keep
    # their current state.
    next_state = torch.zeros_like(state)
    next_state[~done] = observation.next_state[~done]
    next_state[done] = state[done]

    return observation, next_state, done