def post_process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> None:
    """Post-process the data from the provided replay buffer.

    Typical usage is to update the sampling weight in prioritized
    experience replay. Used in :meth:`update`; check out
    :ref:`policy_concept` for more information.
    """
    # Duck-typed check: only buffers that support priority updates
    # (e.g. PrioritizedReplayBuffer) expose ``update_weight``.
    if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
        buffer.update_weight(indice, batch.weight)
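# A minimal sketch of how this hook is chained with :meth:`process_fn`
# and :meth:`learn` during one training step. The helper name
# ``_update_sketch`` is hypothetical and the real ``update`` method may
# differ; the point is only the call order: sample a batch, pre-process
# it, learn from it, then write updated priorities back to the buffer.
def _update_sketch(policy, sample_size: int, buffer: ReplayBuffer):
    batch, indice = buffer.sample(sample_size)
    batch = policy.process_fn(batch, buffer, indice)
    result = policy.learn(batch)
    policy.post_process_fn(batch, buffer, indice)
    return result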
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    r"""Compute the n-step return for Q-learning targets:

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
        \gamma^n (1 - d_{t + n}) \max_a
        Q_{old}(s_{t + n}, \arg\max_a (Q_{new}(s_{t + n}, a)))

    where :math:`\gamma` is the discount factor,
    :math:`\gamma \in [0, 1]`, and :math:`d_t` is the done flag of step
    :math:`t`. If there is no target network, :math:`Q_{old}` is equal
    to :math:`Q_{new}`.
    """
    # Accumulate the discounted reward sum backwards over the n-step
    # window; a done flag inside the window truncates the return and
    # records (via ``gammas``) where the episode ended.
    returns = np.zeros_like(indice)
    gammas = np.zeros_like(indice) + self._n_step
    for n in range(self._n_step - 1, -1, -1):
        now = (indice + n) % len(buffer)
        gammas[buffer.done[now] > 0] = n
        returns[buffer.done[now] > 0] = 0
        returns = buffer.rew[now] + self._gamma * returns
    terminal = (indice + self._n_step - 1) % len(buffer)
    terminal_data = buffer[terminal]
    if self._target:
        # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
        a = self(terminal_data, input='obs_next', eps=0).act
        target_q = self(
            terminal_data, model='model_old', input='obs_next').logits
        if isinstance(target_q, torch.Tensor):
            target_q = target_q.detach().cpu().numpy()
        target_q = target_q[np.arange(len(a)), a]
    else:
        target_q = self(terminal_data, input='obs_next').logits
        if isinstance(target_q, torch.Tensor):
            target_q = target_q.detach().cpu().numpy()
        target_q = target_q.max(axis=1)
    # Bootstrap only for samples whose n-step window contains no done
    # flag; otherwise the episode ended and the target must be zero.
    target_q[gammas != self._n_step] = 0
    returns += (self._gamma ** gammas) * target_q
    batch.returns = returns
    if isinstance(buffer, PrioritizedReplayBuffer):
        # Recompute the TD error to refresh the sampling priorities and
        # to build the importance-weighted loss.
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = batch.returns
        if isinstance(r, np.ndarray):
            r = torch.tensor(r, device=q.device, dtype=q.dtype)
        td = r - q
        buffer.update_weight(indice, td.detach().cpu().numpy())
        impt_weight = torch.tensor(batch.impt_weight,
                                   device=q.device, dtype=torch.float)
        loss = (td.pow(2) * impt_weight).mean()
        if not hasattr(batch, 'loss'):
            batch.loss = loss
        else:
            batch.loss += loss
    return batch
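# A standalone numeric sketch of the n-step backup implemented above
# (illustrative values only; ``rew``/``done`` mirror the buffer fields
# and ``target_q`` stands in for max_a Q(s_{t+n}, a)). When an episode
# ends inside the window, the bootstrap term is dropped and the return
# truncates at the done flag, exactly as the vectorized loop does.
def _n_step_return_sketch() -> float:
    gamma, n_step = 0.9, 3
    rew = [1.0, 1.0, 1.0]           # r_t, r_{t+1}, r_{t+2}
    done = [False, True, False]     # episode ends at step t + 1
    target_q = 5.0                  # hypothetical bootstrap value
    returns, gammas = 0.0, n_step
    for n in range(n_step - 1, -1, -1):
        if done[n]:                 # truncate: drop rewards past the end
            returns, gammas = 0.0, n
        returns = rew[n] + gamma * returns
    if gammas == n_step:            # bootstrap only if no done was seen
        returns += gamma ** gammas * target_q
    return returns                  # 1.0 + 0.9 * 1.0 = 1.9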