def inner_loss(self, episodes, pr=False, iw=False):
    returns = episodes.returns
    pi = self.policy(episodes.observations)
    log_probs = pi.log_prob(episodes.actions)
    if log_probs.dim() > 2:
        log_probs = torch.sum(log_probs, dim=2)

    ### apply importance sampling / policy relaxation
    if pr:
        log_probs_old = episodes.log_probs
        if log_probs_old.dim() > 2:
            log_probs_old = torch.sum(log_probs_old, dim=2)
        ### compute p(x)/q(x), estimating p(x) under samples from q(x)
        importance_ = torch.ones(episodes.rewards.size(1)).to(self.device)
        importances = []
        for log_prob, log_prob_old, mask in zip(log_probs, log_probs_old,
                                                episodes.mask):
            importance_ = importance_ * torch.div(
                log_prob.exp() * mask + self.pr_smooth,
                log_prob_old.exp() * mask + self.pr_smooth)
            importance_ = weighted_normalize(importance_)
            importance_ = importance_ - importance_.min()
            importances.append(importance_)
        importances = torch.stack(importances, dim=0)
        importances = importances.detach()

    ### apply importance weighting
    if iw:
        weights = torch.sum(returns * episodes.mask, dim=0) / torch.sum(
            episodes.mask, dim=0)
        weights = weights.max() - weights
        weights = weighted_normalize(weights)
        weights = weights - weights.min()

    if pr and iw:
        ### the proposed method, PR + IW
        t_loss = torch.sum(
            importances * returns * log_probs * episodes.mask,
            dim=0) / torch.sum(episodes.mask, dim=0)
        loss = -torch.mean(weights * t_loss)
    elif pr and not iw:
        ### only apply Policy Relaxation
        loss = -weighted_mean(log_probs * returns * importances,
                              weights=episodes.mask)
    elif not pr and iw:
        ### only apply Importance Weighting
        t_loss = torch.sum(returns * log_probs * episodes.mask,
                           dim=0) / torch.sum(episodes.mask, dim=0)
        loss = -torch.mean(weights * t_loss)
    else:
        ### the baseline
        loss = -weighted_mean(log_probs * returns, weights=episodes.mask)
    return loss
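### NOTE: `weighted_mean` and `weighted_normalize` are helpers assumed by the losses
### above but not shown in this listing. A minimal sketch consistent with how they are
### called (masked mean, and masked standardization); this is an illustrative
### assumption, not necessarily the original implementation:
import torch

def weighted_mean(tensor, weights=None, eps=1e-8):
    # Mean of `tensor` where `weights` (e.g. the episode mask) marks valid entries.
    if weights is None:
        return torch.mean(tensor)
    return torch.sum(tensor * weights) / (torch.sum(weights) + eps)

def weighted_normalize(tensor, weights=None, eps=1e-8):
    # Standardize `tensor` to zero mean / unit variance over the valid (masked) entries.
    if weights is None:
        weights = torch.ones_like(tensor)
    mean = weighted_mean(tensor, weights=weights, eps=eps)
    var = weighted_mean((tensor - mean) ** 2, weights=weights, eps=eps)
    return (tensor - mean) * weights / torch.sqrt(var + eps)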
def step(self, episodes, clip=False, recurrent=False, seq_len=5):
    if self.baseline is None:
        returns = episodes.returns
    else:
        self.baseline.fit(episodes)
        values = self.baseline(episodes)
        advantages = episodes.gae(values, tau=self.tau)
        returns = weighted_normalize(advantages, weights=episodes.mask)

    if recurrent:
        ### observations have shape (time_horizon, batch_size, state_dim)
        obs = episodes.observations
        log_probs = []
        for idx in range(obs.size(0)):
            ### feed a sliding window of at most `seq_len` timesteps
            if idx < seq_len:
                obs_seq = obs[:idx + 1]
            else:
                obs_seq = obs[idx + 1 - seq_len:idx + 1]
            pi = self.policy(obs_seq)
            log_prob = pi.log_prob(episodes.actions[idx])
            log_probs.append(log_prob)
        log_probs = torch.stack(log_probs, dim=0)
    else:
        pi = self.policy(episodes.observations)
        log_probs = pi.log_prob(episodes.actions)
    if log_probs.dim() > 2:
        log_probs = torch.sum(log_probs, dim=2)

    loss = -weighted_mean(log_probs * returns, weights=episodes.mask)
    self.opt.zero_grad()
    loss.backward()
    if clip:
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
    self.opt.step()
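### Small standalone check of the sliding-window slicing used in the recurrent branch
### above: at each timestep `idx` the policy sees at most the last `seq_len`
### observations. Purely illustrative; the tensor shapes here are made up.
import torch

obs = torch.randn(8, 4, 3)   # (time_horizon=8, batch_size=4, state_dim=3)
seq_len = 5
for idx in range(obs.size(0)):
    obs_seq = obs[:idx + 1] if idx < seq_len else obs[idx + 1 - seq_len:idx + 1]
    # Window length grows 1, 2, ..., seq_len and then stays at seq_len.
    assert obs_seq.size(0) == min(idx + 1, seq_len)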
def step(self, episodes):
    losses = []
    for (train_episodes, valid_episodes) in episodes:
        if self.baseline is None:
            advantages = valid_episodes.returns
        else:
            self.baseline.fit(valid_episodes)
            values = self.baseline(valid_episodes)
            advantages = valid_episodes.gae(values, tau=self.tau)
            advantages = weighted_normalize(advantages,
                                            weights=valid_episodes.mask)

        params = self.adapt(train_episodes)
        pi = self.policy(valid_episodes.observations, params=params)
        log_probs = pi.log_prob(valid_episodes.actions)
        if log_probs.dim() > 2:
            log_probs = torch.sum(log_probs, dim=2)
        loss = -weighted_mean(log_probs * advantages,
                              weights=valid_episodes.mask)
        losses.append(loss)

    total_loss = torch.mean(torch.stack(losses, dim=0))
    self.opt.zero_grad()
    total_loss.backward()
    self.opt.step()
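### NOTE: `episodes.gae(values, tau)` is assumed to implement Generalized Advantage
### Estimation. A minimal standalone sketch over tensors of shape
### (time_horizon, batch_size); the discount `gamma`, the zero bootstrap value after
### the horizon, and the function name are assumptions made for illustration:
import torch

def gae(rewards, values, mask, gamma=0.95, tau=1.0):
    # Mask out padded timesteps, append a zero bootstrap value after the horizon,
    # then accumulate the TD residuals backwards in time with decay gamma * tau.
    values = values * mask
    values = torch.cat([values, torch.zeros_like(values[:1])], dim=0)
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = torch.zeros_like(deltas)
    gae_t = torch.zeros_like(deltas[0])
    for t in reversed(range(deltas.size(0))):
        gae_t = gae_t * gamma * tau + deltas[t]
        advantages[t] = gae_t
    return advantages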
def surrogate_loss(self, episodes, old_pi=None):
    with torch.set_grad_enabled(old_pi is None):
        if self.baseline is None:
            advantages = episodes.returns
        else:
            self.baseline.fit(episodes)
            values = self.baseline(episodes)
            advantages = episodes.gae(values, tau=self.tau)
            advantages = weighted_normalize(advantages, weights=episodes.mask)

        pi = self.policy(episodes.observations)
        if old_pi is None:
            old_pi = detach_distribution(pi)

        log_ratio = pi.log_prob(episodes.actions) - old_pi.log_prob(
            episodes.actions)
        if log_ratio.dim() > 2:
            log_ratio = torch.sum(log_ratio, dim=2)
        ratio = torch.exp(log_ratio)

        loss = -weighted_mean(ratio * advantages, weights=episodes.mask)

        mask = episodes.mask
        if episodes.actions.dim() > 2:
            mask = mask.unsqueeze(2)
        kl = weighted_mean(kl_divergence(pi, old_pi), weights=mask)
    return loss, kl, pi
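### NOTE: `detach_distribution` is assumed to return a copy of the policy distribution
### whose parameters are detached from the autograd graph, so the "old" policy stays
### fixed while the surrogate loss and KL term are differentiated w.r.t. the current
### policy only. A sketch for the common distribution types; illustrative only:
from torch.distributions import Categorical, Independent, Normal

def detach_distribution(pi):
    if isinstance(pi, Categorical):
        return Categorical(logits=pi.logits.detach())
    if isinstance(pi, Normal):
        return Normal(loc=pi.loc.detach(), scale=pi.scale.detach())
    if isinstance(pi, Independent):
        return Independent(detach_distribution(pi.base_dist),
                           pi.reinterpreted_batch_ndims)
    raise NotImplementedError('Only Categorical, Normal and Independent '
                              'distributions are handled in this sketch.')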
def step_ppo(self, episodes, epochs=5, ppo_clip=0.2, clip=False):
    advantages, old_pis = [], []
    for (train_episodes, valid_episodes) in episodes:
        if self.baseline is None:
            advantage = valid_episodes.returns
        else:
            self.baseline.fit(valid_episodes)
            values = self.baseline(valid_episodes)
            advantage = valid_episodes.gae(values, tau=self.tau)
            advantage = weighted_normalize(advantage,
                                           weights=valid_episodes.mask)
        advantages.append(advantage)

        params = self.adapt(train_episodes)
        pi = self.policy(valid_episodes.observations, params=params)
        old_pis.append(detach_distribution(pi))

    for epoch in range(epochs):
        losses, clipped_losses, masks = [], [], []
        for idx, (train_episodes, valid_episodes) in enumerate(episodes):
            params = self.adapt(train_episodes)
            pi = self.policy(valid_episodes.observations, params=params)

            log_ratio = pi.log_prob(valid_episodes.actions) - \
                old_pis[idx].log_prob(valid_episodes.actions)
            if log_ratio.dim() > 2:
                log_ratio = torch.sum(log_ratio, dim=2)
            ratio = torch.exp(log_ratio)

            loss = advantages[idx] * ratio
            clipped_ratio = torch.clamp(ratio, 1.0 - ppo_clip, 1.0 + ppo_clip)
            clipped_loss = advantages[idx] * clipped_ratio

            losses.append(loss)
            clipped_losses.append(clipped_loss)
            masks.append(valid_episodes.mask)

        losses = torch.cat(losses, dim=0)
        clipped_losses = torch.cat(clipped_losses, dim=0)
        masks = torch.cat(masks, dim=0)
        total_loss = -weighted_mean(torch.min(losses, clipped_losses),
                                    weights=masks)

        self.opt.zero_grad()
        total_loss.backward()
        if clip:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
        self.opt.step()
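### The epoch loop above mirrors the standard PPO clipped surrogate:
###     L_clip = E_t[ min( r_t(theta) * A_t,
###                        clip(r_t(theta), 1 - eps, 1 + eps) * A_t ) ],
### where r_t(theta) = pi_theta(a_t | s_t) / pi_old(a_t | s_t) and eps = ppo_clip.
### Taking the elementwise minimum with the clipped term removes the incentive to push
### the ratio outside [1 - eps, 1 + eps], which is what allows several gradient epochs
### to safely reuse the `old_pis` collected before the loop.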
def adapt(self, episodes, first_order=False):
    if self.baseline is None:
        advantages = episodes.returns
    else:
        self.baseline.fit(episodes)
        values = self.baseline(episodes)
        advantages = episodes.gae(values, tau=self.tau)
        advantages = weighted_normalize(advantages, weights=episodes.mask)

    pi = self.policy(episodes.observations)
    log_probs = pi.log_prob(episodes.actions)
    if log_probs.dim() > 2:
        log_probs = torch.sum(log_probs, dim=2)
    loss = -weighted_mean(log_probs * advantages, weights=episodes.mask)

    # Get the new parameters after a one-step gradient update
    params = self.policy.update_params(loss, step_size=self.fast_lr,
                                       first_order=first_order)
    return params
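### NOTE: `self.policy.update_params(loss, step_size, first_order)` is assumed to take
### one gradient step on the inner loss and return the adapted parameters out-of-place,
### so the meta-gradient can flow through the update when `first_order` is False.
### A minimal sketch of such an update; the helper name `one_step_adapted_params` and
### the OrderedDict layout are assumptions for illustration:
from collections import OrderedDict
import torch

def one_step_adapted_params(policy, loss, step_size=0.1, first_order=False):
    named_params = OrderedDict(policy.named_parameters())
    grads = torch.autograd.grad(loss, list(named_params.values()),
                                create_graph=not first_order)
    # theta' = theta - alpha * grad, kept out-of-place so the result can be passed
    # back into the policy's forward pass, e.g. policy(obs, params=params).
    return OrderedDict((name, param - step_size * grad)
                       for (name, param), grad in zip(named_params.items(), grads))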
def surrogate_loss(self, episodes, old_pis=None):
    losses, kls, pis = [], [], []
    if old_pis is None:
        old_pis = [None] * len(episodes)

    for (train_episodes, valid_episodes), old_pi in zip(episodes, old_pis):
        params = self.adapt(train_episodes)

        with torch.set_grad_enabled(old_pi is None):
            pi = self.policy(valid_episodes.observations, params=params)
            pis.append(detach_distribution(pi))
            if old_pi is None:
                old_pi = detach_distribution(pi)

            if self.baseline is None:
                advantages = valid_episodes.returns
            else:
                self.baseline.fit(valid_episodes)
                values = self.baseline(valid_episodes)
                advantages = valid_episodes.gae(values, tau=self.tau)
                advantages = weighted_normalize(advantages,
                                                weights=valid_episodes.mask)

            log_ratio = (pi.log_prob(valid_episodes.actions)
                         - old_pi.log_prob(valid_episodes.actions))
            if log_ratio.dim() > 2:
                log_ratio = torch.sum(log_ratio, dim=2)
            ratio = torch.exp(log_ratio)

            loss = -weighted_mean(ratio * advantages,
                                  weights=valid_episodes.mask)
            losses.append(loss)

            mask = valid_episodes.mask
            if valid_episodes.actions.dim() > 2:
                mask = mask.unsqueeze(2)
            kl = weighted_mean(kl_divergence(pi, old_pi), weights=mask)
            kls.append(kl)

    return (torch.mean(torch.stack(losses, dim=0)),
            torch.mean(torch.stack(kls, dim=0)),
            pis)
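### Typical usage in a trust-region update (illustrative; the surrounding line-search
### details are assumptions): the first call builds the differentiable loss and returns
### the detached `pis`, and later calls during the line search reuse them as `old_pis`,
### so candidate parameters are evaluated against a fixed old policy with gradients
### disabled.
# old_loss, _, old_pis = self.surrogate_loss(episodes)
# ... compute a candidate step from grad(old_loss) under the KL constraint ...
# new_loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
# accept the step only if new_loss improves and kl stays below the trust-region bound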
def surrogate_loss(self, episodes, old_pi=None, pr=False, iw=False):
    with torch.set_grad_enabled(old_pi is None):
        pi = self.policy(episodes.observations)
        if old_pi is None:
            old_pi = detach_distribution(pi)

        advantages = episodes.returns
        log_ratio = pi.log_prob(episodes.actions) - old_pi.log_prob(
            episodes.actions)
        if log_ratio.dim() > 2:
            log_ratio = torch.sum(log_ratio, dim=2)
        ratio = torch.exp(log_ratio)

        ### apply importance sampling / policy relaxation
        if pr:
            log_probs = pi.log_prob(episodes.actions)
            if log_probs.dim() > 2:
                log_probs = torch.sum(log_probs, dim=2)
            log_probs_old = episodes.log_probs
            if log_probs_old.dim() > 2:
                log_probs_old = torch.sum(log_probs_old, dim=2)
            ### compute p(x)/q(x), estimating p(x) under samples from q(x)
            importance_ = torch.ones(episodes.rewards.size(1)).to(self.device)
            importances = []
            for log_prob, log_prob_old, mask in zip(log_probs, log_probs_old,
                                                    episodes.mask):
                importance_ = importance_ * torch.div(
                    log_prob.exp() * mask + self.pr_smooth,
                    log_prob_old.exp() * mask + self.pr_smooth)
                importance_ = weighted_normalize(importance_)
                importance_ = importance_ - importance_.min()
                importances.append(importance_)
            importances = torch.stack(importances, dim=0)
            importances = importances.detach()

        ### apply importance weighting
        if iw:
            weights = torch.sum(episodes.returns * episodes.mask,
                                dim=0) / torch.sum(episodes.mask, dim=0)
            ### if the rewards are negative, use "r_max - r" as the weighting metric
            if self.iw_inv:
                weights = weights.max() - weights
            weights = weighted_normalize(weights)
            weights = weights - weights.min()

        if pr and iw:
            ### the proposed method, PR + IW
            t_loss = torch.sum(
                importances * advantages * ratio * episodes.mask,
                dim=0) / torch.sum(episodes.mask, dim=0)
            loss = -torch.mean(weights * t_loss)
        elif pr and not iw:
            ### only apply Policy Relaxation
            loss = -weighted_mean(ratio * advantages * importances,
                                  weights=episodes.mask)
        elif not pr and iw:
            ### only apply Importance Weighting
            t_loss = torch.sum(advantages * ratio * episodes.mask,
                               dim=0) / torch.sum(episodes.mask, dim=0)
            loss = -torch.mean(weights * t_loss)
        else:
            ### the baseline
            loss = -weighted_mean(ratio * advantages, weights=episodes.mask)

        mask = episodes.mask
        if episodes.actions.dim() > 2:
            mask = mask.unsqueeze(2)
        kl = weighted_mean(kl_divergence(pi, old_pi), weights=mask)
    return loss, kl, pi
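### In the PR branch above, the importance on each trajectory at timestep t is the
### running product of per-step likelihood ratios between the current policy and the
### behaviour policy that collected the episodes, smoothed by `pr_smooth` to avoid
### division by zero on masked steps:
###     w_t = prod_{k <= t} (pi_theta(a_k | s_k) + c) / (pi_old(a_k | s_k) + c),
### renormalized and shifted to be non-negative at every step, and detached so it acts
### as a fixed per-timestep weight rather than a differentiable term. The IW branch
### weights each trajectory by its normalized average return, switching to
### "r_max - r" when `iw_inv` is set (per the comment, for negative-reward settings).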