def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    r"""Compute the discounted returns for each transition.

    .. math::
        G_t = \sum_{i=t}^T \gamma^{i-t}r_i

    where :math:`T` is the terminal time step, :math:`\gamma` is the
    discount factor, :math:`\gamma \in [0, 1]`.
    """
    v_s_ = np.full(indice.shape, self.ret_rms.mean)
    unnormalized_returns, _ = self.compute_episodic_return(
        batch, buffer, indice, v_s_=v_s_,
        gamma=self._gamma, gae_lambda=1.0)
    if self._rew_norm:
        batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
    else:
        batch.returns = unnormalized_returns
    return batch
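# Hedged illustration (not Tianshou code): the docstring above reduces to the
# backward recursion G_t = r_t + gamma * (1 - d_t) * G_{t+1}, which is what
# compute_episodic_return with gae_lambda=1.0 effectively evaluates. The
# helper name below is hypothetical.
import numpy as np

def discounted_returns(rew: np.ndarray, done: np.ndarray,
                       gamma: float = 0.99) -> np.ndarray:
    """Backward pass over one buffer slice; resets at episode boundaries."""
    returns = np.zeros_like(rew, dtype=float)
    running = 0.0
    for t in range(len(rew) - 1, -1, -1):
        running = rew[t] + gamma * (1.0 - done[t]) * running
        returns[t] = running
    return returns

# e.g. rew=[1, 1, 1], done=[0, 0, 1], gamma=0.9 gives G = [2.71, 1.9, 1.0].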
def _compute_returns(self, batch: Batch, buffer: ReplayBuffer,
                     indice: np.ndarray) -> Batch:
    v_s, v_s_ = [], []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            v_s.append(self.critic(b.obs))
            v_s_.append(self.critic(b.obs_next))
    batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
    v_s = batch.v_s.cpu().numpy()
    v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
    # When normalizing values, we do not subtract self.ret_rms.mean, to stay
    # numerically consistent with OpenAI Baselines' value-normalization
    # pipeline. Empirically, subtracting the mean also hurts performance
    # slightly, for unknown reasons (observed on MuJoCo envs; low confidence).
    if self._rew_norm:  # unnormalize v_s & v_s_
        v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
        v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
    unnormalized_returns, advantages = self.compute_episodic_return(
        batch, buffer, indice, v_s_, v_s,
        gamma=self._gamma, gae_lambda=self._lambda)
    if self._rew_norm:
        batch.returns = unnormalized_returns / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
    else:
        batch.returns = unnormalized_returns
    batch.returns = to_torch_as(batch.returns, batch.v_s)
    batch.adv = to_torch_as(advantages, batch.v_s)
    return batch
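# The self.ret_rms object used above is assumed to expose mean/var/update in
# the style of OpenAI Baselines' RunningMeanStd. A minimal sketch of such an
# object, using the standard parallel mean/variance update (Chan et al.);
# this is an assumption, not Tianshou's exact implementation.
import numpy as np

class RunningMeanStd:
    def __init__(self) -> None:
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, x: np.ndarray) -> None:
        batch_mean, batch_var, batch_count = x.mean(), x.var(), len(x)
        delta = batch_mean - self.mean
        total = self.count + batch_count
        # Combine the running statistics with the batch statistics.
        m2 = (self.var * self.count + batch_var * batch_count +
              delta ** 2 * self.count * batch_count / total)
        self.mean += delta * batch_count / total
        self.var = m2 / total
        self.count = total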
def process_fn(
    self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
    v_s, v_s_, old_log_prob = [], [], []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            v_s.append(self.critic(b.obs))
            v_s_.append(self.critic(b.obs_next))
            old_log_prob.append(
                self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
    batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
    v_s = to_numpy(batch.v_s)
    v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
    if self._rew_norm:  # unnormalize v_s & v_s_
        v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
        v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
    unnormalized_returns, advantages = self.compute_episodic_return(
        batch, buffer, indice, v_s_, v_s,
        gamma=self._gamma, gae_lambda=self._lambda)
    if self._rew_norm:
        batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
        mean, std = np.mean(advantages), np.std(advantages)
        advantages = (advantages - mean) / std  # per-batch norm
    else:
        batch.returns = unnormalized_returns
    batch.act = to_torch_as(batch.act, batch.v_s)
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    batch.returns = to_torch_as(batch.returns, batch.v_s)
    batch.adv = to_torch_as(advantages, batch.v_s)
    return batch
def compute_episodic_return(
    batch: Batch,
    v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> Batch:
    """Compute returns over given full-length episodes, including the
    implementation of Generalized Advantage Estimation (arXiv:1506.02438).

    :param batch: a data batch which contains several full-episode data
        chronologically.
    :type batch: :class:`~tianshou.data.Batch`
    :param v_s_: the value function of all next states :math:`V(s')`.
    :type v_s_: numpy.ndarray
    :param float gamma: the discount factor, should be in [0, 1], defaults
        to 0.99.
    :param float gae_lambda: the parameter for Generalized Advantage
        Estimation, should be in [0, 1], defaults to 0.95.
    """
    if v_s_ is None:
        v_s_ = np.zeros_like(batch.rew)
    else:
        if not isinstance(v_s_, np.ndarray):
            v_s_ = np.array(v_s_, dtype=float)  # np.float is deprecated
        v_s_ = v_s_.reshape(batch.rew.shape)
    batch.returns = np.roll(v_s_, 1, axis=0)
    m = (1.0 - batch.done) * gamma
    delta = batch.rew + v_s_ * m - batch.returns
    m *= gae_lambda
    gae = 0.0
    for i in range(len(batch.rew) - 1, -1, -1):
        gae = delta[i] + m[i] * gae
        batch.returns[i] += gae
    return batch
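# For readers tracing the vectorized code above: an equivalent unvectorized
# sketch of the same GAE recursion, with delta_t = r_t + gamma*(1-d_t)*V(s_{t+1})
# - V(s_t) and A_t = delta_t + gamma*lambda*(1-d_t)*A_{t+1}. Here v_s is
# V(s_t), which the code above recovers via np.roll(v_s_, 1). The helper name
# is hypothetical.
import numpy as np

def gae_advantage(rew, v_s, v_s_, done, gamma=0.99, gae_lambda=0.95):
    delta = rew + gamma * v_s_ * (1.0 - done) - v_s
    adv = np.zeros_like(rew, dtype=float)
    gae = 0.0
    for t in range(len(rew) - 1, -1, -1):
        gae = delta[t] + gamma * gae_lambda * (1.0 - done[t]) * gae
        adv[t] = gae
    return adv  # returns-to-go are then adv + v_s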
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    self._batch = batch_size
    r = batch.returns
    if self._rew_norm and r.std() > self.__eps:
        batch.returns = (r - r.mean()) / r.std()
    losses, actor_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            v = self.critic(b.obs)
            a = torch.tensor(b.act, device=v.device)
            r = torch.tensor(b.returns, device=v.device)
            a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
            vf_loss = F.mse_loss(r[:, None], v)
            ent_loss = dist.entropy().mean()
            loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
            loss.backward()
            if self._grad_norm:
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    max_norm=self._grad_norm)
            self.optim.step()
            actor_losses.append(a_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    return {
        'loss': losses,
        'loss/actor': actor_losses,
        'loss/vf': vf_losses,
        'loss/ent': ent_losses,
    }
def compute_episodic_return(
    batch: Batch,
    v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    rew_norm: bool = False,
) -> Batch:
    """Compute returns over given full-length episodes.

    Implementation of Generalized Advantage Estimator (arXiv:1506.02438).

    :param batch: a data batch which contains several full-episode data
        chronologically.
    :type batch: :class:`~tianshou.data.Batch`
    :param v_s_: the value function of all next states :math:`V(s')`.
    :type v_s_: numpy.ndarray
    :param float gamma: the discount factor, should be in [0, 1], defaults
        to 0.99.
    :param float gae_lambda: the parameter for Generalized Advantage
        Estimation, should be in [0, 1], defaults to 0.95.
    :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
        to False.

    :return: a Batch. The result will be stored in batch.returns as a numpy
        array with shape (bsz, ).
    """
    rew = batch.rew
    v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten())
    returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
    if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
        returns = (returns - returns.mean()) / returns.std()
    batch.returns = returns
    return batch
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    if self._rew_norm:
        mean, std = batch.rew.mean(), batch.rew.std()
        if not np.isclose(std, 0, 1e-2):
            batch.rew = (batch.rew - mean) / std
    v, v_, old_log_prob = [], [], []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False):
            v_.append(self.critic(b.obs_next))
            v.append(self.critic(b.obs))
            old_log_prob.append(self(b).dist.log_prob(
                to_torch_as(b.act, v[0])))
    v_ = to_numpy(torch.cat(v_, dim=0))
    batch = self.compute_episodic_return(
        batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
        rew_norm=self._rew_norm)
    batch.v = torch.cat(v, dim=0).flatten()  # old value
    batch.act = to_torch_as(batch.act, v[0])
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    batch.returns = to_torch_as(batch.returns, v[0])
    batch.adv = batch.returns - batch.v
    if self._rew_norm:
        mean, std = batch.adv.mean(), batch.adv.std()
        if not np.isclose(std.item(), 0, 1e-2):
            batch.adv = (batch.adv - mean) / std
    return batch
def compute_nstep_return(
    batch: Batch,
    buffer: ReplayBuffer,
    indice: np.ndarray,
    target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
    gamma: float = 0.99,
    n_step: int = 1,
    rew_norm: bool = False,
) -> Batch:
    r"""Compute n-step return for Q-learning targets.

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
        \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

    where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
    :math:`d_t` is the done flag of step :math:`t`.

    :param batch: a data batch, which is equal to buffer[indice].
    :type batch: :class:`~tianshou.data.Batch`
    :param buffer: a data buffer which contains several full-episode data
        chronologically.
    :type buffer: :class:`~tianshou.data.ReplayBuffer`
    :param indice: sampled timestep.
    :type indice: numpy.ndarray
    :param function target_q_fn: a function that receives step
        :math:`t+n-1`'s data and computes the target Q value.
    :param float gamma: the discount factor, should be in [0, 1], defaults
        to 0.99.
    :param int n_step: the number of estimation step, should be an int
        greater than 0, defaults to 1.
    :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
        to False.

    :return: a Batch. The result will be stored in batch.returns as a
        torch.Tensor with shape (bsz, ).
    """
    rew = buffer.rew
    if rew_norm:
        bfr = rew[:min(len(buffer), 1000)]  # avoid large buffer
        mean, std = bfr.mean(), bfr.std()
        if np.isclose(std, 0, 1e-2):
            mean, std = 0.0, 1.0
    else:
        mean, std = 0.0, 1.0
    buf_len = len(buffer)
    terminal = (indice + n_step - 1) % buf_len
    target_q_torch = target_q_fn(buffer, terminal).flatten()  # (bsz, )
    target_q = to_numpy(target_q_torch)

    target_q = _nstep_return(rew, buffer.done, target_q, indice,
                             gamma, n_step, len(buffer), mean, std)

    batch.returns = to_torch_as(target_q, target_q_torch)
    # prio buffer update
    if isinstance(buffer, PrioritizedReplayBuffer):
        batch.weight = to_torch_as(batch.weight, target_q_torch)
    return batch
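# _nstep_return is referenced but not defined here. A plausible loop-based
# sketch, mirroring the inline loop of the older variant further below: fold
# rewards backward through the n-step window, cut at done flags, then
# bootstrap with the target Q only where the episode survived the full
# window. Treat this as an illustration, not Tianshou's exact kernel.
import numpy as np

def _nstep_return_sketch(rew, done, target_q, indice, gamma, n_step,
                         buf_len, mean=0.0, std=1.0):
    returns = np.zeros(indice.shape)
    gammas = np.full(indice.shape, n_step)
    for n in range(n_step - 1, -1, -1):
        now = (indice + n) % buf_len
        gammas[done[now] > 0] = n       # episode ended inside the window
        returns[done[now] > 0] = 0.0    # drop rewards past the boundary
        returns = (rew[now] - mean) / std + gamma * returns
    target_q[gammas != n_step] = 0.0    # no bootstrap across episode ends
    return target_q * (gamma ** gammas) + returns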
def compute_nstep_return(batch: Batch, buffer: ReplayBuffer,
                         indice: np.ndarray,
                         target_q_fn: Callable[
                             [ReplayBuffer, np.ndarray], torch.Tensor],
                         gamma: float = 0.99, n_step: int = 1,
                         rew_norm: bool = False) -> Batch:
    r"""Compute the n-step return for Q-learning targets:

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
        \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n}),

    where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
    :math:`d_t` is the done flag of step :math:`t`.

    :param batch: a data batch, which is equal to buffer[indice].
    :type batch: :class:`~tianshou.data.Batch`
    :param buffer: a data buffer which contains several full-episode data
        chronologically.
    :type buffer: :class:`~tianshou.data.ReplayBuffer`
    :param indice: sampled timestep.
    :type indice: numpy.ndarray
    :param function target_q_fn: a function that receives step
        :math:`t+n-1`'s data and computes the target Q value.
    :param float gamma: the discount factor, should be in [0, 1], defaults
        to 0.99.
    :param int n_step: the number of estimation step, should be an int
        greater than 0, defaults to 1.
    :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
        to ``False``.

    :return: a Batch. The result will be stored in batch.returns as a
        torch.Tensor with shape (bsz, ).
    """
    if rew_norm:
        bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
        mean, std = bfr.mean(), bfr.std()
        if np.isclose(std, 0):
            mean, std = 0, 1
    else:
        mean, std = 0, 1
    returns = np.zeros_like(indice)
    gammas = np.zeros_like(indice) + n_step
    done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
    for n in range(n_step - 1, -1, -1):
        now = (indice + n) % buf_len
        gammas[done[now] > 0] = n
        returns[done[now] > 0] = 0
        returns = (rew[now] - mean) / std + gamma * returns
    terminal = (indice + n_step - 1) % buf_len
    target_q = target_q_fn(buffer, terminal).squeeze()
    target_q[gammas != n_step] = 0
    returns = to_torch_as(returns, target_q)
    gammas = to_torch_as(gamma ** gammas, target_q)
    batch.returns = target_q * gammas + returns
    return batch
def compute_nstep_return(
    batch: Batch,
    buffer: ReplayBuffer,
    indice: np.ndarray,
    target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
    gamma: float = 0.99,
    n_step: int = 1,
    rew_norm: bool = False,
    use_mixed: bool = False,
) -> Batch:
    r"""Compute n-step return for Q-learning targets.

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
        \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

    where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
    :math:`d_t` is the done flag of step :math:`t`.

    :param Batch batch: a data batch, which is equal to buffer[indice].
    :param ReplayBuffer buffer: the data buffer.
    :param function target_q_fn: a function which computes the target Q value
        of "obs_next" given the data buffer and wanted indices.
    :param float gamma: the discount factor, should be in [0, 1]. Default to
        0.99.
    :param int n_step: the number of estimation step, should be an int
        greater than 0. Default to 1.
    :param bool rew_norm: normalize the reward to Normal(0, 1). Default to
        False.
    :param bool use_mixed: whether to compute the target Q value under
        autocast (mixed precision). Default to False.

    :return: a Batch. The result will be stored in batch.returns as a
        torch.Tensor with the same shape as target_q_fn's return tensor.
    """
    assert not rew_norm, \
        "Reward normalization in computing n-step returns is unsupported now."
    rew = buffer.rew
    bsz = len(indice)
    indices = [indice]
    for _ in range(n_step - 1):
        indices.append(buffer.next(indices[-1]))
    indices = np.stack(indices)
    # terminal indicates buffer indexes n_step after 'indice',
    # and is truncated at the end of each episode
    terminal = indices[-1]
    with autocast(enabled=use_mixed):
        with torch.no_grad():
            target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
    target_q = to_numpy(target_q_torch.float().reshape(bsz, -1))
    target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)
    end_flag = buffer.done.copy()
    end_flag[buffer.unfinished_index()] = True
    target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step)

    batch.returns = to_torch_as(target_q, target_q_torch)
    if hasattr(batch, "weight"):  # prio buffer update
        batch.weight = to_torch_as(batch.weight, target_q_torch)
    return batch
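# Hedged toy illustration of the buffer.next chaining above: next(idx) is
# assumed to return idx + 1 inside an episode and idx itself at an episode's
# last step, so stacking n_step rows truncates each sample's horizon at its
# episode end. toy_next is a stand-in, not the ReplayBuffer API.
import numpy as np

done = np.array([0, 0, 1, 0, 0])  # episode boundary at index 2

def toy_next(idx: np.ndarray) -> np.ndarray:
    return np.where(done[idx] > 0, idx, np.minimum(idx + 1, len(done) - 1))

indices = [np.array([0, 1, 2, 3])]
for _ in range(3 - 1):  # n_step = 3
    indices.append(toy_next(indices[-1]))
print(np.stack(indices)[-1])
# terminal row: [2 2 2 4] -- stuck at the episode end (2) or buffer end (4)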
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    r"""Compute the n-step return for Q-learning targets:

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
        \gamma^n (1 - d_{t + n})
        \max_a Q_{old}(s_{t + n}, \arg\max_a (Q_{new}(s_{t + n}, a))),

    where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
    :math:`d_t` is the done flag of step :math:`t`. If there is no target
    network, the :math:`Q_{old}` is equal to :math:`Q_{new}`.
    """
    returns = np.zeros_like(indice)
    gammas = np.zeros_like(indice) + self._n_step
    for n in range(self._n_step - 1, -1, -1):
        now = (indice + n) % len(buffer)
        gammas[buffer.done[now] > 0] = n
        returns[buffer.done[now] > 0] = 0
        returns = buffer.rew[now] + self._gamma * returns
    terminal = (indice + self._n_step - 1) % len(buffer)
    terminal_data = buffer[terminal]
    if self._target:
        # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
        a = self(terminal_data, input='obs_next', eps=0).act
        target_q = self(terminal_data, model='model_old',
                        input='obs_next').logits
        if isinstance(target_q, torch.Tensor):
            target_q = target_q.detach().cpu().numpy()
        target_q = target_q[np.arange(len(a)), a]
    else:
        target_q = self(terminal_data, input='obs_next').logits
        if isinstance(target_q, torch.Tensor):
            target_q = target_q.detach().cpu().numpy()
        target_q = target_q.max(axis=1)
    target_q[gammas != self._n_step] = 0
    returns += (self._gamma ** gammas) * target_q
    batch.returns = returns
    if isinstance(buffer, PrioritizedReplayBuffer):
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = batch.returns
        if isinstance(r, np.ndarray):
            r = torch.tensor(r, device=q.device, dtype=q.dtype)
        td = r - q
        buffer.update_weight(indice, td.detach().cpu().numpy())
        impt_weight = torch.tensor(batch.impt_weight,
                                   device=q.device, dtype=torch.float)
        loss = (td.pow(2) * impt_weight).mean()
        if not hasattr(batch, 'loss'):
            batch.loss = loss
        else:
            batch.loss += loss
    return batch
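# The target branch above is Double DQN: the online network selects the
# action, the (frozen) target network evaluates it. A compact standalone
# sketch of just that target; the q_online/q_target names are hypothetical
# stand-ins for Q_new and Q_old.
import numpy as np

def double_q_target(q_online: np.ndarray, q_target: np.ndarray) -> np.ndarray:
    """Both inputs: (bsz, n_actions) Q-values at s_{t+n}."""
    a_star = q_online.argmax(axis=1)                 # argmax under Q_new
    return q_target[np.arange(len(a_star)), a_star]  # evaluate under Q_old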
def process_fn(
    self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
    v_s_ = []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            v_s_.append(to_numpy(self.critic(b.obs_next)))
    v_s_ = np.concatenate(v_s_, axis=0)
    if self._rew_norm:  # unnormalize v_s_
        v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
    unnormalized_returns, _ = self.compute_episodic_return(
        batch, buffer, indice, v_s_=v_s_,
        gamma=self._gamma, gae_lambda=self._lambda)
    if self._rew_norm:
        batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
    else:
        batch.returns = unnormalized_returns
    return batch
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    losses = []
    r = batch.returns
    if self._rew_norm and not np.isclose(r.std(), 0):
        batch.returns = (r - r.mean()) / r.std()
    for _ in range(repeat):
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            a = torch.tensor(b.act, device=dist.logits.device)
            r = torch.tensor(b.returns, device=dist.logits.device)
            loss = -(dist.log_prob(a) * r).sum()
            loss.backward()
            self.optim.step()
            losses.append(loss.item())
    return {'loss': losses}
def compute_episodic_return(
    batch: Batch,
    buffer: ReplayBuffer,
    indice: np.ndarray,
    v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    rew_norm: bool = False,
) -> Batch:
    """Compute returns over given batch.

    Uses Generalized Advantage Estimation (arXiv:1506.02438) to calculate
    the Q function/reward-to-go of given batch.

    :param Batch batch: a data batch which contains several episodes of data
        in sequential order. Mind that the end of each finished episode of
        batch should be marked by done flag; unfinished (or collecting)
        episodes will be recognized by buffer.unfinished_index().
    :param numpy.ndarray indice: tell batch's location in buffer, batch is
        equal to buffer[indice].
    :param np.ndarray v_s_: the value function of all next states
        :math:`V(s')`.
    :param float gamma: the discount factor, should be in [0, 1]. Default to
        0.99.
    :param float gae_lambda: the parameter for Generalized Advantage
        Estimation, should be in [0, 1]. Default to 0.95.
    :param bool rew_norm: normalize the reward to Normal(0, 1). Default to
        False.

    :return: a Batch. The result will be stored in batch.returns as a numpy
        array with shape (bsz, ).
    """
    rew = batch.rew
    if v_s_ is None:
        assert np.isclose(gae_lambda, 1.0)
        v_s_ = np.zeros_like(rew)
    else:
        v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)
    end_flag = batch.done.copy()
    end_flag[np.isin(indice, buffer.unfinished_index())] = True
    returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda)
    if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
        returns = (returns - returns.mean()) / returns.std()
    batch.returns = returns
    return batch
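# Hedged note on the masking above: value_mask(buffer, indice) is assumed to
# be 0 where indice is an episode's terminal step and 1 elsewhere, so v_s_
# is not bootstrapped past a true terminal; end_flag additionally marks the
# last collected (still unfinished) step so GAE does not leak across the
# collection cut. A toy stand-in for the mask, not the actual API:
import numpy as np

def toy_value_mask(done: np.ndarray, indice: np.ndarray) -> np.ndarray:
    return 1.0 - done[indice]  # 0 at terminal steps, 1 otherwise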
def compute_episodic_return(
    batch: Batch,
    v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    rew_norm: bool = False,
) -> Batch:
    """Compute returns over given full-length episodes, including the
    implementation of Generalized Advantage Estimator (arXiv:1506.02438).

    :param batch: a data batch which contains several full-episode data
        chronologically.
    :type batch: :class:`~tianshou.data.Batch`
    :param v_s_: the value function of all next states :math:`V(s')`.
    :type v_s_: numpy.ndarray
    :param float gamma: the discount factor, should be in [0, 1], defaults
        to 0.99.
    :param float gae_lambda: the parameter for Generalized Advantage
        Estimation, should be in [0, 1], defaults to 0.95.
    :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
        to ``False``.

    :return: a Batch. The result will be stored in batch.returns as a numpy
        array with shape (bsz, ).
    """
    rew = batch.rew
    v_s_ = rew * 0.0 if v_s_ is None else to_numpy(v_s_).flatten()
    returns = np.roll(v_s_, 1, axis=0)
    m = (1.0 - batch.done) * gamma
    delta = rew + v_s_ * m - returns
    m *= gae_lambda
    gae = 0.0
    for i in range(len(rew) - 1, -1, -1):
        gae = delta[i] + m[i] * gae
        returns[i] += gae
    if rew_norm and not np.isclose(returns.std(), 0, 1e-2):
        returns = (returns - returns.mean()) / returns.std()
    batch.returns = returns
    return batch
def _compute_return(
    self,
    batch: Batch,
    buffer: ReplayBuffer,
    indice: np.ndarray,
    gamma: float = 0.99,
) -> Batch:
    rew = batch.rew
    with torch.no_grad():
        target_q_torch = self._target_q(buffer, indice)  # (bsz, ?)
    target_q = to_numpy(target_q_torch)
    end_flag = buffer.done.copy()
    end_flag[buffer.unfinished_index()] = True
    end_flag = end_flag[indice]
    mean_target_q = np.mean(target_q, -1) \
        if len(target_q.shape) > 1 else target_q
    _target_q = rew + gamma * mean_target_q * (1 - end_flag)
    target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
    target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)

    batch.returns = to_torch_as(target_q, target_q_torch)
    if hasattr(batch, "weight"):  # prio buffer update
        batch.weight = to_torch_as(batch.weight, target_q_torch)
    return batch
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    self._batch = batch_size
    losses, clip_losses, vf_losses, ent_losses = [], [], [], []
    v = []
    old_log_prob = []
    with torch.no_grad():
        for b in batch.split(batch_size, shuffle=False):
            v.append(self.critic(b.obs))
            old_log_prob.append(self(b).dist.log_prob(
                torch.tensor(b.act, device=v[0].device)))
    batch.v = torch.cat(v, dim=0)  # old value
    dev = batch.v.device
    batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
    batch.logp_old = torch.cat(old_log_prob, dim=0)
    batch.returns = torch.tensor(
        batch.returns, dtype=torch.float, device=dev
    ).reshape(batch.v.shape)
    if self._rew_norm:
        mean, std = batch.returns.mean(), batch.returns.std()
        if std > self.__eps:
            batch.returns = (batch.returns - mean) / std
    batch.adv = batch.returns - batch.v
    if self._rew_norm:
        mean, std = batch.adv.mean(), batch.adv.std()
        if std > self.__eps:
            batch.adv = (batch.adv - mean) / std
    for _ in range(repeat):
        for b in batch.split(batch_size):
            dist = self(b).dist
            value = self.critic(b.obs)
            ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
            surr1 = ratio * b.adv
            surr2 = ratio.clamp(
                1. - self._eps_clip, 1. + self._eps_clip) * b.adv
            if self._dual_clip:
                clip_loss = -torch.max(torch.min(surr1, surr2),
                                       self._dual_clip * b.adv).mean()
            else:
                clip_loss = -torch.min(surr1, surr2).mean()
            clip_losses.append(clip_loss.item())
            if self._value_clip:
                v_clip = b.v + (value - b.v).clamp(
                    -self._eps_clip, self._eps_clip)
                vf1 = (b.returns - value).pow(2)
                vf2 = (b.returns - v_clip).pow(2)
                vf_loss = .5 * torch.max(vf1, vf2).mean()
            else:
                vf_loss = .5 * (b.returns - value).pow(2).mean()
            vf_losses.append(vf_loss.item())
            e_loss = dist.entropy().mean()
            ent_losses.append(e_loss.item())
            loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
            losses.append(loss.item())
            self.optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(
                list(self.actor.parameters()) +
                list(self.critic.parameters()),
                self._max_grad_norm)
            self.optim.step()
    return {
        'loss': losses,
        'loss/clip': clip_losses,
        'loss/vf': vf_losses,
        'loss/ent': ent_losses,
    }
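# The loss above combines loss = clip_loss + w_vf * vf_loss - w_ent * e_loss.
# A standalone sketch of just the core clipped surrogate,
# L^CLIP = -E[min(ratio * A, clip(ratio, 1-eps, 1+eps) * A)], assuming
# log-probs and advantages are precomputed; the helper name is hypothetical.
import torch

def ppo_clip_loss(logp: torch.Tensor, logp_old: torch.Tensor,
                  adv: torch.Tensor, eps_clip: float = 0.2) -> torch.Tensor:
    ratio = (logp - logp_old).exp()
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    return -torch.min(surr1, surr2).mean()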