예제 #1
0
 def post_process_fn(self, batch: Batch,
                     buffer: ReplayBuffer, indice: np.ndarray) -> None:
     """Post-process sampled data once it has been consumed.

     If ``buffer`` is a :class:`PrioritizedReplayBuffer` and ``batch``
     carries a ``weight`` attribute, those weights are written back into
     ``buffer`` at positions ``indice``; in every other case this is a
     no-op. See :ref:`policy_concept` for more information.
     """
     # Only prioritized buffers keep per-sample weights.
     if not isinstance(buffer, PrioritizedReplayBuffer):
         return
     # Nothing to write back unless the batch produced new weights.
     if not hasattr(batch, 'weight'):
         return
     buffer.update_weight(indice, batch.weight)
예제 #2
0
    def post_process_fn(self, batch: Batch, buffer: ReplayBuffer,
                        indice: np.ndarray) -> None:
        """Push ``batch.weight`` back into ``buffer`` when both support it.

        The check is duck-typed: any buffer exposing an ``update_weight``
        attribute (such as a prioritized replay buffer) receives the batch's
        ``weight`` values at positions ``indice``. Otherwise nothing happens.
        Used in :meth:`update`.
        """
        buffer_has_hook = hasattr(buffer, "update_weight")
        if not buffer_has_hook or not hasattr(batch, "weight"):
            return
        buffer.update_weight(indice, batch.weight)
예제 #3
0
파일: dqn.py 프로젝트: HUST-WZY/tianshou
    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        r"""Compute the n-step return for Q-learning targets:

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{old}(s_{t + n}, \arg\max_a
            (Q_{new}(s_{t + n}, a)))

        , where :math:`\gamma` is the discount factor,
        :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
        :math:`t`. If there is no target network, the :math:`Q_{old}` is equal
        to :math:`Q_{new}`.

        The result is stored in ``batch.returns``. The bootstrap term is
        dropped for every sample whose n-step window contains a done flag.
        When ``buffer`` is a :class:`PrioritizedReplayBuffer`, the TD errors
        are also written back to the buffer via ``update_weight``, and the
        mean squared TD error weighted by ``batch.impt_weight`` is stored in
        (or added to an existing) ``batch.loss``.

        :param batch: the sampled data; it is modified in place and returned.
        :param buffer: the replay buffer the batch was sampled from.
        :param indice: positions of the sampled transitions in ``buffer``.
        :return: ``batch`` with ``returns`` (and possibly ``loss``) set.
        """
        size = len(buffer)
        # Explicit float accumulator: ``indice`` is an integer index array,
        # so ``np.zeros_like(indice)`` alone would produce integer dtype.
        returns = np.zeros_like(indice, dtype=np.float64)
        # Step offset of the first done inside each window; it keeps the
        # value ``n_step`` when the window contains no done flag.
        gammas = np.zeros_like(indice) + self._n_step
        # Walk the window backwards so that the earliest done is the last
        # one to reset the accumulator.
        for n in range(self._n_step - 1, -1, -1):
            now = (indice + n) % size
            done_mask = buffer.done[now] > 0
            gammas[done_mask] = n
            returns[done_mask] = 0
            returns = buffer.rew[now] + self._gamma * returns
        terminal = (indice + self._n_step - 1) % size
        terminal_data = buffer[terminal]

        def _to_numpy(values):
            # Logits may come back as a tensor or already as an ndarray.
            if isinstance(values, torch.Tensor):
                return values.detach().cpu().numpy()
            return values

        if self._target:
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            a = self(terminal_data, input='obs_next', eps=0).act
            target_q = _to_numpy(self(terminal_data, model='model_old',
                                      input='obs_next').logits)
            target_q = target_q[np.arange(len(a)), a]
        else:
            target_q = _to_numpy(
                self(terminal_data, input='obs_next').logits).max(axis=1)
        # Do not bootstrap past a done flag inside the window.
        target_q[gammas != self._n_step] = 0
        returns += (self._gamma ** gammas) * target_q
        batch.returns = returns
        if isinstance(buffer, PrioritizedReplayBuffer):
            # Compare the current Q(s, a) of the taken actions with the
            # freshly computed returns.
            q = self(batch).logits
            q = q[np.arange(len(q)), batch.act]
            r = batch.returns
            if isinstance(r, np.ndarray):
                r = torch.tensor(r, device=q.device, dtype=q.dtype)
            td = r - q
            # The detached TD error is handed to the buffer as new weights.
            buffer.update_weight(indice, td.detach().cpu().numpy())
            impt_weight = torch.tensor(batch.impt_weight,
                                       device=q.device,
                                       dtype=torch.float)
            loss = (td.pow(2) * impt_weight).mean()
            # Accumulate rather than overwrite a loss already on the batch.
            if not hasattr(batch, 'loss'):
                batch.loss = loss
            else:
                batch.loss += loss
        return batch