from typing import Dict

import numpy as np
import torch.nn.functional as F

# Imports assume Tianshou's data utilities (tianshou.data of the 0.2.x era).
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer, \
    to_numpy, to_torch_as


def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    # Periodically sync the online network's weights into the target network.
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    # Forward pass, then gather Q(s, a) for the actions actually taken.
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns, q)
    if hasattr(batch, 'update_weight'):
        # Prioritized replay: refresh the sampled transitions' priorities
        # with the new TD errors, and correct the squared-error loss with
        # the importance-sampling weights.
        td = r - q
        batch.update_weight(batch.indice, to_numpy(td))
        impt_weight = to_torch_as(batch.impt_weight, q)
        loss = (td.pow(2) * impt_weight).mean()
    else:
        loss = F.mse_loss(q, r)
    loss.backward()
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}
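# Usage sketch (illustrative, not part of the class): a single DQN update
# chains process_fn and learn. `policy` and `buf` are hypothetical names,
# the batch size of 64 is arbitrary, and `buf.sample` returning
# (batch, indice) assumes the ReplayBuffer API of this era.
#
#     batch, indice = buf.sample(64)
#     batch = policy.process_fn(batch, buf, indice)  # attach n-step returns
#     stats = policy.learn(batch)                    # one gradient step
#     print(stats['loss'])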
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    r"""Compute the n-step return for Q-learning targets:

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t} (1 - d_i) r_i
            + \gamma^n (1 - d_{t + n})
            \max_a Q_{old}(s_{t + n}, \arg\max_a Q_{new}(s_{t + n}, a)),

    where :math:`\gamma \in [0, 1]` is the discount factor and
    :math:`d_t` is the done flag of step :math:`t`. If there is no
    target network, then :math:`Q_{old}` is equal to :math:`Q_{new}`.
    """
    batch = self.compute_nstep_return(
        batch, buffer, indice, self._target_q,
        self._gamma, self._n_step)
    if isinstance(buffer, PrioritizedReplayBuffer):
        # Expose the buffer's priority-update hook and the sampled indices
        # so that learn() can refresh priorities after computing TD errors.
        batch.update_weight = buffer.update_weight
        batch.indice = indice
    return batch
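# Worked example of the n-step target above (hypothetical numbers): with
# n = 2, gamma = 0.9, rewards r_t = 1 and r_{t+1} = 1, no terminal flags,
# and a bootstrap value max_a Q_old(s_{t+2}, a) = 5,
#
#     G_t = 1 + 0.9 * 1 + 0.9**2 * 5 = 5.95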