def _init(self, critics: List[CriticSpec], reward_scale: float = 1.0):
    """Register the additional critics and select the critic loss.

    Every critic in ``critics`` is moved to ``self._device``, deep-copied
    into a target critic, and given its own optimizer and scheduler via
    ``utils.get_trainer_components``. The resulting lists are stored with
    the already existing ``self.critic`` / ``self.critic_optimizer`` /
    ``self.critic_scheduler`` / ``self.target_critic`` in first position.
    Afterwards the loss function is chosen from ``self.critic.distribution``.

    Args:
        critics: additional critics to train next to ``self.critic``.
        reward_scale: stored as ``self.reward_scale``; how it is applied
            is not visible in this method.

    Raises:
        AssertionError: if ``self.critic.distribution`` is not one of
            ``None``, ``"categorical"``, ``"quantile"``, or if it is
            ``None`` while ``self.critic_criterion`` is ``None``.
    """
    self.reward_scale = reward_scale
    # @TODO: policy regularization

    extra_critics = [net.to(self._device) for net in critics]
    # targets are copies of the critics after they were moved to the device
    extra_targets = [
        copy.deepcopy(net).to(self._device) for net in extra_critics
    ]

    # only the optimizer and the scheduler of each component set are kept
    components = [
        utils.get_trainer_components(
            agent=net,
            loss_params=self._critic_loss_params,
            optimizer_params=self._critic_optimizer_params,
            scheduler_params=self._critic_scheduler_params,
            grad_clip_params=self._critic_grad_clip_params,
        )
        for net in extra_critics
    ]
    extra_optimizers = [parts["optimizer"] for parts in components]
    extra_schedulers = [parts["scheduler"] for parts in components]

    # the main critic always sits at index 0 of every list
    self.critics = [self.critic, *extra_critics]
    self.critics_optimizer = [self.critic_optimizer, *extra_optimizers]
    self.critics_scheduler = [self.critic_scheduler, *extra_schedulers]
    self.target_critics = [self.target_critic, *extra_targets]

    # value distribution approximation
    distribution = self.critic.distribution
    self._loss_fn = self._base_loss  # default, replaced below if needed

    self._num_heads = self.critic.num_heads
    self._num_critics = len(self.critics)
    self._hyperbolic_constant = self.critic.hyperbolic_constant
    self._gammas = utils.hyperbolic_gammas(
        self._gamma, self._hyperbolic_constant, self._num_heads
    )
    self._gammas = utils.any2device(self._gammas, device=self._device)

    assert distribution in (None, "categorical", "quantile")

    if distribution == "categorical":
        self.num_atoms = self.critic.num_atoms
        low, high = self.critic.values_range
        self.v_min, self.v_max = low, high
        # spacing between neighbouring atoms of the value support
        self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
        support = torch.linspace(
            start=self.v_min, end=self.v_max, steps=self.num_atoms
        )
        self.z = utils.any2device(support, device=self._device)
        self._loss_fn = self._categorical_loss
        return

    if distribution == "quantile":
        self.num_atoms = self.critic.num_atoms
        # evenly spaced fractions from 1/(2N) to 1 - 1/(2N)
        first_tau = 1 / (2 * self.num_atoms)
        last_tau = 1 - first_tau
        fractions = torch.linspace(
            start=first_tau, end=last_tau, steps=self.num_atoms
        )
        self.tau = utils.any2device(fractions, device=self._device)
        self._loss_fn = self._quantile_loss
        return

    # no distribution: the base loss is kept, a criterion must be set
    assert self.critic_criterion is not None
def _init(
    self,
    use_value_clipping: bool = True,
    gae_lambda: float = 0.95,
    clip_eps: float = 0.2,
    entropy_regularization: float = None
):
    """Store the hyperparameters and select the value loss.

    The four arguments are saved on the instance under the same names.
    The value loss function is then chosen from
    ``self.critic.distribution``; any value other than ``"categorical"``
    or ``"quantile"`` keeps ``self._base_value_loss``.

    Args:
        use_value_clipping: when false, ``self.critic_criterion`` must
            be set (checked at the end of this method).
        gae_lambda: stored as ``self.gae_lambda``.
        clip_eps: stored as ``self.clip_eps``.
        entropy_regularization: stored as ``self.entropy_regularization``.

    Raises:
        AssertionError: if ``self.critic_criterion`` is ``None`` for a
            ``"quantile"`` critic or when ``use_value_clipping`` is false.
    """
    self.use_value_clipping = use_value_clipping
    self.gae_lambda = gae_lambda
    self.clip_eps = clip_eps
    self.entropy_regularization = entropy_regularization

    distribution = self.critic.distribution
    self._value_loss_fn = self._base_value_loss  # default choice
    self._num_atoms = self.critic.num_atoms
    self._num_heads = self.critic.num_heads
    self._hyperbolic_constant = self.critic.hyperbolic_constant
    self._gammas = utils.hyperbolic_gammas(
        self._gamma, self._hyperbolic_constant, self._num_heads
    )
    gammas_on_device = utils.any2device(self._gammas, device=self._device)
    # 1 x num_heads x 1
    self._gammas_torch = gammas_on_device[None, :, None]

    if distribution == "categorical":
        self.num_atoms = self.critic.num_atoms
        low, high = self.critic.values_range
        self.v_min, self.v_max = low, high
        # spacing between neighbouring atoms of the value support
        self.delta_z = (self.v_max - self.v_min) / (self._num_atoms - 1)
        support = torch.linspace(
            start=self.v_min, end=self.v_max, steps=self._num_atoms
        )
        self.z = utils.any2device(support, device=self._device)
        self._value_loss_fn = self._categorical_value_loss
    elif distribution == "quantile":
        assert self.critic_criterion is not None

        self.num_atoms = self.critic.num_atoms
        # evenly spaced fractions from 1/(2N) to 1 - 1/(2N)
        first_tau = 1 / (2 * self._num_atoms)
        last_tau = 1 - first_tau
        fractions = torch.linspace(
            start=first_tau, end=last_tau, steps=self._num_atoms
        )
        self.tau = utils.any2device(fractions, device=self._device)
        self._value_loss_fn = self._quantile_value_loss

    # NOTE(review): kept at function level, since inside the quantile
    # branch it would only repeat the assert above - confirm upstream.
    if not self.use_value_clipping:
        assert self.critic_criterion is not None
def _init(self, entropy_regularization: float = None):
    """Store the entropy coefficient and select the critic loss.

    The loss function is chosen from ``self.critic.distribution``:
    ``"categorical"`` and ``"quantile"`` install their dedicated losses
    together with the tensors they need (``self.z`` or ``self.tau``),
    ``None`` keeps ``self._base_loss``.

    Args:
        entropy_regularization: stored as ``self.entropy_regularization``.

    Raises:
        AssertionError: if the distribution is not one of ``None``,
            ``"categorical"``, ``"quantile"``; if a categorical critic
            is combined with a non-``None`` ``self.critic_criterion``;
            or if any other critic has no ``self.critic_criterion``.
    """
    self.entropy_regularization = entropy_regularization

    # value distribution approximation
    distribution = self.critic.distribution
    self._loss_fn = self._base_loss  # default, replaced below if needed

    self._num_heads = self.critic.num_heads
    self._hyperbolic_constant = self.critic.hyperbolic_constant
    self._gammas = utils.hyperbolic_gammas(
        self._gamma, self._hyperbolic_constant, self._num_heads
    )
    self._gammas = utils.any2device(self._gammas, device=self._device)

    assert distribution in (None, "categorical", "quantile")

    if distribution == "categorical":
        # the categorical loss is used without an external criterion
        assert self.critic_criterion is None
        self.num_atoms = self.critic.num_atoms
        low, high = self.critic.values_range
        self.v_min, self.v_max = low, high
        # spacing between neighbouring atoms of the value support
        self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
        support = torch.linspace(
            start=self.v_min, end=self.v_max, steps=self.num_atoms
        )
        self.z = utils.any2device(support, device=self._device)
        self._loss_fn = self._categorical_loss
        return

    # both remaining cases (quantile and no distribution) need a criterion
    if distribution == "quantile":
        assert self.critic_criterion is not None
        self.num_atoms = self.critic.num_atoms
        # evenly spaced fractions from 1/(2N) to 1 - 1/(2N)
        first_tau = 1 / (2 * self.num_atoms)
        last_tau = 1 - first_tau
        fractions = torch.linspace(
            start=first_tau, end=last_tau, steps=self.num_atoms
        )
        self.tau = utils.any2device(fractions, device=self._device)
        self._loss_fn = self._quantile_loss
        return

    assert self.critic_criterion is not None