def process_fn(
    self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
) -> Batch:
    """Pre-process the data from the provided replay buffer.

    Used in :meth:`update`. Check out :ref:`process_fn` for more information.

    Steps:
    1. Run ``self.model`` on ``(batch.obs, batch.act, batch.obs_next)``; it
       returns a pair ``(mse_loss, act_hat)``.
    2. Stash the unmodified reward, ``act_hat`` and ``mse_loss`` in
       ``batch.policy``, so later stages can use or restore them.
    3. Add ``mse_loss * self.reward_scale`` (as a numpy array) to
       ``batch.rew`` as a reward bonus.
    4. Delegate to the wrapped ``self.policy.process_fn``.

    :param batch: sampled transitions; must provide ``obs``, ``act``,
        ``obs_next`` and ``rew``.
    :param buffer: replay buffer the batch was sampled from; forwarded as-is.
    :param indices: buffer indices of the sampled transitions; forwarded
        as-is.
    :return: whatever ``self.policy.process_fn`` returns for the
        reward-augmented batch.
    """
    mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next)
    batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss)
    # Out-of-place addition on purpose. ``orig_rew`` above may reference the
    # very same ndarray as ``batch.rew``. An in-place ``+=`` would silently
    # overwrite the "original" reward with the augmented one, and it would
    # also fail when ``rew`` has an integer dtype. Rebinding leaves the
    # stashed array untouched.
    # NOTE(review): assumes ``mse_loss`` is per-sample and broadcastable
    # against ``batch.rew`` -- confirm against the model implementation.
    batch.rew = batch.rew + to_numpy(mse_loss * self.reward_scale)
    return self.policy.process_fn(batch, buffer, indices)