Exemple #1
0
    def surrogate_loss(self, episodes, old_pi=None):
        """Compute the probability-ratio surrogate loss and a KL term.

        If ``old_pi`` is ``None`` a detached copy of the freshly computed
        policy distribution is used in its place and the computation runs
        with gradients enabled; otherwise gradients are disabled for the
        whole computation.

        Args:
            episodes: batch object; this method reads ``observations``,
                ``actions``, ``mask``, ``returns`` and calls ``gae``.
            old_pi: distribution to compare the current policy against,
                or ``None``.

        Returns:
            Tuple ``(loss, kl, pi)`` where ``pi`` is the distribution
            returned by ``self.policy`` for the batch observations.
        """
        with torch.set_grad_enabled(old_pi is None):
            if self.baseline is not None:
                # Refit the baseline on this batch, then turn its value
                # estimates into normalized GAE advantages.
                self.baseline.fit(episodes)
                baseline_values = self.baseline(episodes)
                raw_advantages = episodes.gae(baseline_values, tau=self.tau)
                advantages = weighted_normalize(raw_advantages,
                                                weights=episodes.mask)
            else:
                advantages = episodes.returns

            pi = self.policy(episodes.observations)
            if old_pi is None:
                old_pi = detach_distribution(pi)

            new_log_prob = pi.log_prob(episodes.actions)
            old_log_prob = old_pi.log_prob(episodes.actions)
            log_ratio = new_log_prob - old_log_prob
            if log_ratio.dim() > 2:
                # Collapse the trailing (third) dimension of the log-ratio.
                log_ratio = torch.sum(log_ratio, dim=2)
            ratio = torch.exp(log_ratio)
            loss = -weighted_mean(ratio * advantages, weights=episodes.mask)

            kl_weights = episodes.mask
            if episodes.actions.dim() > 2:
                kl_weights = kl_weights.unsqueeze(2)
            kl = weighted_mean(kl_divergence(pi, old_pi), weights=kl_weights)

        return loss, kl, pi
Exemple #2
0
    def step(self, episodes, clip=False, recurrent=False, seq_len=5):
        """Perform one gradient update of the policy through ``self.opt``.

        The loss is the negated masked mean of ``log_prob * signal``, where
        the signal is ``episodes.returns`` when there is no baseline and the
        normalized GAE advantages otherwise.

        Args:
            episodes: batch object; this method reads ``observations``,
                ``actions``, ``mask``, ``returns`` and calls ``gae``.
            clip: when true, clip the gradient norm of the policy
                parameters to 1.0 before the optimizer step.
            recurrent: when true, evaluate the policy once per time index
                on a trailing window of observations.
            seq_len: maximum window length used in the recurrent branch.
        """
        if self.baseline is not None:
            self.baseline.fit(episodes)
            baseline_values = self.baseline(episodes)
            raw_advantages = episodes.gae(baseline_values, tau=self.tau)
            returns = weighted_normalize(raw_advantages, weights=episodes.mask)
        else:
            returns = episodes.returns

        if not recurrent:
            pi = self.policy(episodes.observations)
            log_probs = pi.log_prob(episodes.actions)
        else:
            # Original note: observations are laid out as
            # (time_horizon, batch_size, state_dim).
            obs = episodes.observations
            per_step = []
            for t in range(obs.size(0)):
                # Window of at most `seq_len` observations ending at t.
                window_start = max(0, t - seq_len + 1)
                pi = self.policy(obs[window_start:t + 1])
                per_step.append(pi.log_prob(episodes.actions[t]))
            log_probs = torch.stack(per_step, dim=0)

        if log_probs.dim() > 2:
            log_probs = torch.sum(log_probs, dim=2)

        loss = -weighted_mean(log_probs * returns, weights=episodes.mask)

        self.opt.zero_grad()
        loss.backward()
        if clip:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
        self.opt.step()
Exemple #3
0
 def kl_divergence(self, episodes, old_pi=None):
     """Return the masked mean KL divergence between the policy and ``old_pi``.

     When ``old_pi`` is ``None`` a detached copy of the current policy
     distribution is used as the second argument instead.
     """
     pi = self.policy(episodes.observations)
     reference = detach_distribution(pi) if old_pi is None else old_pi

     weights = episodes.mask
     if episodes.actions.dim() > 2:
         # Add a trailing axis to the mask when actions have more than 2 dims.
         weights = weights.unsqueeze(2)

     # `kl_divergence` here resolves to the module-level function, not
     # this method.
     return weighted_mean(kl_divergence(pi, reference), weights=weights)
def reinforce_loss(policy, episodes, params=None):
    """Return the mean REINFORCE-style loss over a batch of episodes.

    Observations and actions are flattened along their leading dimensions
    before being passed to the policy; the resulting log-probabilities are
    reshaped to ``(len(episodes), episodes.batch_size)``, multiplied by
    ``episodes.advantages`` and reduced with ``weighted_mean`` using the
    episode lengths.

    Args:
        policy: callable taking observations and a ``params`` keyword and
            returning an object with ``log_prob``.
        episodes: batch object providing observations, actions, advantages,
            lengths and shape metadata.
        params: forwarded unchanged to ``policy``.

    Returns:
        The mean of the negated weighted means.
    """
    flat_observations = episodes.observations.view(
        (-1, *episodes.observation_shape))
    pi = policy(flat_observations, params=params)

    flat_actions = episodes.actions.view((-1, *episodes.action_shape))
    flat_log_probs = pi.log_prob(flat_actions)
    log_probs = flat_log_probs.view(len(episodes), episodes.batch_size)

    weighted = log_probs * episodes.advantages
    losses = -weighted_mean(weighted, lengths=episodes.lengths)

    return losses.mean()
Exemple #5
0
    def surrogate_loss(self, episodes, old_pis=None):
        """Average the surrogate loss and KL term over pairs of episodes.

        For every ``(train_episodes, valid_episodes)`` pair the parameters
        are adapted on the training episodes and the loss and KL are
        evaluated on the validation episodes. Gradients are enabled for a
        pair only when its entry in ``old_pis`` is ``None``.

        Args:
            episodes: sequence of ``(train_episodes, valid_episodes)`` pairs.
            old_pis: one distribution (or ``None``) per pair; ``None``
                means "use ``None`` for every pair".

        Returns:
            Tuple ``(mean_loss, mean_kl, pis)`` where ``pis`` holds a
            detached copy of each pair's policy distribution.
        """
        if old_pis is None:
            old_pis = [None] * len(episodes)

        losses, kls, pis = [], [], []
        for (train_episodes, valid_episodes), old_pi in zip(episodes, old_pis):
            params = self.adapt(train_episodes)
            with torch.set_grad_enabled(old_pi is None):
                pi = self.policy(valid_episodes.observations, params=params)
                pis.append(detach_distribution(pi))

                if old_pi is None:
                    old_pi = detach_distribution(pi)

                if self.baseline is not None:
                    self.baseline.fit(valid_episodes)
                    baseline_values = self.baseline(valid_episodes)
                    raw_advantages = valid_episodes.gae(baseline_values,
                                                        tau=self.tau)
                    advantages = weighted_normalize(
                        raw_advantages, weights=valid_episodes.mask)
                else:
                    advantages = valid_episodes.returns

                new_log_prob = pi.log_prob(valid_episodes.actions)
                old_log_prob = old_pi.log_prob(valid_episodes.actions)
                log_ratio = new_log_prob - old_log_prob
                if log_ratio.dim() > 2:
                    log_ratio = torch.sum(log_ratio, dim=2)
                ratio = torch.exp(log_ratio)

                losses.append(-weighted_mean(ratio * advantages,
                                             weights=valid_episodes.mask))

                kl_weights = valid_episodes.mask
                if valid_episodes.actions.dim() > 2:
                    kl_weights = kl_weights.unsqueeze(2)
                kls.append(weighted_mean(kl_divergence(pi, old_pi),
                                         weights=kl_weights))

        mean_loss = torch.mean(torch.stack(losses, dim=0))
        mean_kl = torch.mean(torch.stack(kls, dim=0))
        return (mean_loss, mean_kl, pis)
Exemple #6
0
    def step_ppo(self, episodes, epochs=5, ppo_clip=0.2, clip=False):
        """Run several clipped-ratio (PPO-style) updates over episode pairs.

        A first pass stores, for each ``(train_episodes, valid_episodes)``
        pair, the advantage tensor and a detached copy of the adapted
        policy distribution. Each epoch then re-adapts, forms the
        probability ratio against the stored distribution, and takes one
        optimizer step on the negated masked mean of the element-wise
        minimum of the unclipped and clipped objectives.

        Args:
            episodes: sequence of ``(train_episodes, valid_episodes)`` pairs.
            epochs: number of optimizer steps to take.
            ppo_clip: half-width of the interval around 1.0 the ratio is
                clamped to.
            clip: when true, clip the gradient norm of the policy
                parameters to 1.0 before each optimizer step.
        """
        advantages, old_pis = [], []
        for train_episodes, valid_episodes in episodes:
            if self.baseline is not None:
                self.baseline.fit(valid_episodes)
                baseline_values = self.baseline(valid_episodes)
                advantage = weighted_normalize(
                    valid_episodes.gae(baseline_values, tau=self.tau),
                    weights=valid_episodes.mask)
            else:
                advantage = valid_episodes.returns
            advantages.append(advantage)

            adapted = self.adapt(train_episodes)
            reference_pi = self.policy(valid_episodes.observations,
                                       params=adapted)
            old_pis.append(detach_distribution(reference_pi))

        ratio_low, ratio_high = 1.0 - ppo_clip, 1.0 + ppo_clip
        for _ in range(epochs):
            losses, clipped_losses, masks = [], [], []
            for (train_episodes, valid_episodes), advantage, old_pi in zip(
                    episodes, advantages, old_pis):
                adapted = self.adapt(train_episodes)
                pi = self.policy(valid_episodes.observations, params=adapted)

                log_ratio = (pi.log_prob(valid_episodes.actions)
                             - old_pi.log_prob(valid_episodes.actions))
                if log_ratio.dim() > 2:
                    log_ratio = torch.sum(log_ratio, dim=2)
                ratio = torch.exp(log_ratio)

                losses.append(advantage * ratio)
                clipped_losses.append(
                    advantage * torch.clamp(ratio, ratio_low, ratio_high))
                masks.append(valid_episodes.mask)

            losses = torch.cat(losses, dim=0)
            clipped_losses = torch.cat(clipped_losses, dim=0)
            masks = torch.cat(masks, dim=0)

            objective = torch.min(losses, clipped_losses)
            total_loss = -weighted_mean(objective, weights=masks)

            self.opt.zero_grad()
            total_loss.backward()
            if clip:
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
            self.opt.step()
Exemple #7
0
    def kl_divergence(self, episodes, old_pis=None):
        """Return the mean masked KL divergence across pairs of episodes.

        For each ``(train_episodes, valid_episodes)`` pair the parameters
        are adapted on the training episodes and the KL between the
        resulting policy distribution and the matching entry of ``old_pis``
        is evaluated on the validation episodes. A ``None`` entry is
        replaced by a detached copy of the current distribution.

        Args:
            episodes: sequence of ``(train_episodes, valid_episodes)`` pairs.
            old_pis: one distribution (or ``None``) per pair; ``None``
                means "use ``None`` for every pair".
        """
        if old_pis is None:
            old_pis = [None] * len(episodes)

        kls = []
        for (train_episodes, valid_episodes), old_pi in zip(episodes, old_pis):
            adapted = self.adapt(train_episodes)
            pi = self.policy(valid_episodes.observations, params=adapted)

            reference = detach_distribution(pi) if old_pi is None else old_pi

            weights = valid_episodes.mask
            if valid_episodes.actions.dim() > 2:
                weights = weights.unsqueeze(2)
            # `kl_divergence` here resolves to the module-level function,
            # not this method.
            kls.append(weighted_mean(kl_divergence(pi, reference),
                                     weights=weights))

        return torch.mean(torch.stack(kls, dim=0))
Exemple #8
0
    def surrogate_loss(self, episodes, old_pi=None, pr=False, iw=False):
        """Surrogate loss with optional importance sampling and weighting.

        The base objective is ``ratio * episodes.returns`` where ``ratio``
        is the probability ratio between the current policy and ``old_pi``.
        Gradients are enabled only when ``old_pi`` is ``None``, in which
        case a detached copy of the current distribution replaces it.

        Args:
            episodes: batch object; this method reads ``observations``,
                ``actions``, ``returns``, ``rewards``, ``log_probs`` and
                ``mask``.
            old_pi: distribution to compare the current policy against,
                or ``None``.
            pr: when true, multiply the objective by detached cumulative
                importance factors built from the current log-probs and
                ``episodes.log_probs`` (smoothed by ``self.pr_smooth``).
            iw: when true, weight the dim-0-reduced objective by weights
                derived from the masked mean of the returns (flipped when
                ``self.iw_inv`` is set).

        Returns:
            Tuple ``(loss, kl, pi)``.
        """
        with torch.set_grad_enabled(old_pi is None):
            pi = self.policy(episodes.observations)
            if old_pi is None:
                old_pi = detach_distribution(pi)

            advantages = episodes.returns
            log_ratio = (pi.log_prob(episodes.actions)
                         - old_pi.log_prob(episodes.actions))
            if log_ratio.dim() > 2:
                log_ratio = torch.sum(log_ratio, dim=2)
            ratio = torch.exp(log_ratio)

            # Importance sampling.
            if pr:
                new_log_probs = pi.log_prob(episodes.actions)
                if new_log_probs.dim() > 2:
                    new_log_probs = torch.sum(new_log_probs, dim=2)
                old_log_probs = episodes.log_probs
                if old_log_probs.dim() > 2:
                    old_log_probs = torch.sum(old_log_probs, dim=2)

                # Running product of smoothed p(x)/q(x) factors along dim 0,
                # re-normalized and shifted to start at zero after each step.
                running = torch.ones(episodes.rewards.size(1)).to(self.device)
                per_step = []
                for new_lp, old_lp, step_mask in zip(new_log_probs,
                                                     old_log_probs,
                                                     episodes.mask):
                    running = running * torch.div(
                        new_lp.exp() * step_mask + self.pr_smooth,
                        old_lp.exp() * step_mask + self.pr_smooth)
                    running = weighted_normalize(running)
                    running = running - running.min()
                    per_step.append(running)
                importances = torch.stack(per_step, dim=0)
                importances = importances.detach()

            # Importance weighting.
            if iw:
                weights = (torch.sum(episodes.returns * episodes.mask, dim=0)
                           / torch.sum(episodes.mask, dim=0))
                # If the rewards are negative, use "max - r" as the
                # weighting metric instead.
                if self.iw_inv:
                    weights = weights.max() - weights
                weights = weighted_normalize(weights)
                weights = weights - weights.min()

            if iw:
                if pr:
                    masked = importances * advantages * ratio * episodes.mask
                else:
                    masked = advantages * ratio * episodes.mask
                t_loss = (torch.sum(masked, dim=0)
                          / torch.sum(episodes.mask, dim=0))
                loss = -torch.mean(weights * t_loss)
            elif pr:
                loss = -weighted_mean(ratio * advantages * importances,
                                      weights=episodes.mask)
            else:
                loss = -weighted_mean(ratio * advantages,
                                      weights=episodes.mask)

            kl_weights = episodes.mask
            if episodes.actions.dim() > 2:
                kl_weights = kl_weights.unsqueeze(2)
            kl = weighted_mean(kl_divergence(pi, old_pi), weights=kl_weights)

        return loss, kl, pi