def _update_policy(self, batch: TensorDict) -> dict:
    # One gradient step on the actor loss; the `optimize` context manager
    # handles zeroing gradients and stepping the optimizer.
    with self.optimizers.optimize("actor"):
        loss, info = self.loss_actor(batch)
        loss.backward()

    info.update(self.extra_grad_info("actor"))

    # Move the target actor towards the online actor.
    main, target = self.module.actor, self.module.target_actor
    update_polyak(main, target, self.config["polyak"])
    return info
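# NOTE: in the methods above and below, `loss.backward()` is the only
# statement inside the `with self.optimizers.optimize(...)` block and no
# explicit `optimizer.step()` call appears, which suggests the context
# manager zeroes gradients on entry and steps the named optimizer on exit.
# The sketch below shows that assumed behavior; `OptimizerCollection` is a
# hypothetical name, not an API confirmed by this excerpt.
from contextlib import contextmanager

import torch


class OptimizerCollection:
    """Hypothetical mapping from names to torch optimizers."""

    def __init__(self, **optimizers: torch.optim.Optimizer):
        self._optimizers = optimizers

    @contextmanager
    def optimize(self, name: str):
        # Clear stale gradients before the caller computes the loss ...
        optimizer = self._optimizers[name]
        optimizer.zero_grad()
        yield optimizer
        # ... then apply the gradients accumulated by `loss.backward()`.
        optimizer.step()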
def _update_critic(self, batch: TensorDict) -> dict:
    # One gradient step on the critic loss.
    with self.optimizers.optimize("critics"):
        loss, info = self.loss_critic(batch)
        loss.backward()

    info.update(self.extra_grad_info("critics"))

    # Move the target critics towards the online critics.
    main, target = self.module.critics, self.module.target_critics
    update_polyak(main, target, self.config["polyak"])
    return info
def improve_policy(self, batch: TensorDict) -> dict: with self.optimizers.optimize("naf"): loss, info = self.loss_fn(batch) loss.backward() info.update(self.extra_grad_info()) vcritics, target_vcritics = self.module.vcritics, self.module.target_vcritics update_polyak(vcritics, target_vcritics, self.config["polyak"]) return info
def improve_policy(self, batch: TensorDict) -> dict:
    self._grad_step += 1
    self._info["grad_steps"] = self._grad_step

    # Update the critics every step; delay actor updates so they happen
    # only every `policy_delay` critic updates (TD3-style).
    self._info.update(self._update_critic(batch))
    if self._grad_step % self.config["policy_delay"] == 0:
        self._info.update(self._update_policy(batch))

    # Move the target critics towards the online critics.
    critics, target_critics = self.module.critics, self.module.target_critics
    update_polyak(critics, target_critics, self.config["polyak"])
    return self._info.copy()
def improve_policy(self, batch: TensorDict) -> dict:
    info = {}
    info.update(self._update_critic(batch))
    info.update(self._update_actor(batch))
    # Tune the entropy coefficient only when a target entropy is configured.
    if self.config["target_entropy"] is not None:
        info.update(self._update_alpha(batch))

    # Move the target critics towards the online critics.
    update_polyak(
        self.module.critics, self.module.target_critics, self.config["polyak"]
    )
    return info
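# The body of `_update_alpha` is not shown in this excerpt. A common
# SAC-style temperature update it might implement minimizes the loss below,
# where `log_alpha` is a learnable log-temperature and `log_prob` holds the
# policy's action log-probabilities for the batch. The function name and
# signature are illustrative assumptions, not part of this codebase.
import torch


def sac_alpha_loss(
    log_alpha: torch.Tensor, log_prob: torch.Tensor, target_entropy: float
) -> torch.Tensor:
    # Pushes alpha up when policy entropy falls below `target_entropy`
    # and down when it exceeds it.
    return -(log_alpha * (log_prob + target_entropy).detach()).mean()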
def _update_polyak(self):
    update_polyak(
        self.module.critic, self.module.target_critic, self.config["polyak"]
    )
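# `update_polyak` is used throughout but not defined in this excerpt. The
# conventional implementation is an in-place Polyak (exponential moving)
# average of the target network's parameters, sketched below under that
# assumption:
import torch.nn as nn


def update_polyak(main: nn.Module, target: nn.Module, polyak: float):
    """target <- polyak * target + (1 - polyak) * main, parameter-wise."""
    for main_p, target_p in zip(main.parameters(), target.parameters()):
        target_p.data.mul_(polyak).add_(main_p.data, alpha=1 - polyak)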