def _init(
    self,
    critics,
    action_noise_std=0.2,
    action_noise_clip=0.5,
    values_range=(-10., 10.),
    critic_distribution=None,
    **kwargs
):
    """Set up the critic ensemble and select the critic loss.

    Args:
        critics: extra critic networks (in addition to ``self.critic``);
            each gets its own optimizer, scheduler and target copy,
            built from the same ``critic_*_params`` as the main critic.
        action_noise_std: std of the Gaussian noise added to target
            actions (TD3-style target policy smoothing).
        action_noise_clip: absolute clip value for that noise.
        values_range: ``(v_min, v_max)`` support of the value
            distribution; used only by the "categorical" parametrization.
        critic_distribution: ``None`` (plain loss), ``"quantile"`` or
            ``"categorical"``.
        **kwargs: forwarded to the parent ``_init``.

    Raises:
        ValueError: if ``critic_distribution`` is not a supported value.
    """
    super()._init(**kwargs)

    # Fail fast on typos instead of silently training with the base loss.
    if critic_distribution not in (None, "quantile", "categorical"):
        raise ValueError(
            f"unknown critic_distribution: {critic_distribution!r}; "
            "expected None, 'quantile' or 'categorical'")

    # Number of output atoms of the value head (1 for a plain critic).
    self.n_atoms = self.critic.out_features
    self._loss_fn = self._base_loss
    self.action_noise_std = action_noise_std
    self.action_noise_clip = action_noise_clip

    # Per-critic optimizers/schedulers/target copies for the extra
    # critics; the main critic's components were created by __init__.
    critics = [x.to(self._device) for x in critics]
    critics_optimizer = [
        OPTIMIZERS.get_from_params(
            **self.critic_optimizer_params,
            params=prepare_optimizable_params(x)
        )
        for x in critics
    ]
    critics_scheduler = [
        SCHEDULERS.get_from_params(
            **self.critic_scheduler_params,
            optimizer=x
        )
        for x in critics_optimizer
    ]
    target_critics = [copy.deepcopy(x).to(self._device) for x in critics]

    # Main critic always goes first in each ensemble list.
    self.critics = [self.critic] + critics
    self.critics_optimizer = [self.critic_optimizer] + critics_optimizer
    self.critics_scheduler = [self.critic_scheduler] + critics_scheduler
    self.target_critics = [self.target_critic] + target_critics

    if critic_distribution == "quantile":
        # Midpoints of n_atoms equal-probability bins in (0, 1).
        tau_min = 1 / (2 * self.n_atoms)
        tau_max = 1 - tau_min
        tau = torch.linspace(
            start=tau_min, end=tau_max, steps=self.n_atoms)
        self.tau = self._to_tensor(tau)
        self._loss_fn = self._quantile_loss
    elif critic_distribution == "categorical":
        # Fixed support z with uniform atom spacing delta_z.
        self.v_min, self.v_max = values_range
        self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)
        z = torch.linspace(
            start=self.v_min, end=self.v_max, steps=self.n_atoms)
        self.z = self._to_tensor(z)
        self._loss_fn = self._categorical_loss
def get_trainer_components(
    *,
    agent,
    loss_params=None,
    optimizer_params=None,
    scheduler_params=None,
    grad_clip_params=None
):
    """Build criterion, optimizer, scheduler and grad-clip fn for ``agent``.

    Every ``*_params`` dict is defensively copied via ``_copy_params``;
    the copies are returned alongside the constructed components so they
    can be persisted with checkpoints.
    """
    # criterion (moved to GPU when available)
    loss_params = _copy_params(loss_params)
    criterion = CRITERIONS.get_from_params(**loss_params)
    if criterion is not None and torch.cuda.is_available():
        criterion = criterion.cuda()

    # optimizer over the agent's trainable parameters
    optimizer_params = _copy_params(optimizer_params)
    trainable = UtilsFactory.get_optimizable_params(agent.parameters())
    optimizer = OPTIMIZERS.get_from_params(
        **optimizer_params,
        params=trainable
    )

    # scheduler bound to that optimizer
    scheduler_params = _copy_params(scheduler_params)
    scheduler = SCHEDULERS.get_from_params(
        **scheduler_params,
        optimizer=optimizer
    )

    # gradient clipping
    grad_clip_params = _copy_params(grad_clip_params)
    grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)

    return {
        "loss_params": loss_params,
        "criterion": criterion,
        "optimizer_params": optimizer_params,
        "optimizer": optimizer,
        "scheduler_params": scheduler_params,
        "scheduler": scheduler,
        "grad_clip_params": grad_clip_params,
        "grad_clip_fn": grad_clip_fn,
    }
def __init__(
    self,
    actor,
    critic,
    gamma,
    n_step,
    actor_optimizer_params,
    critic_optimizer_params,
    actor_grad_clip_params=None,
    critic_grad_clip_params=None,
    actor_loss_params=None,
    critic_loss_params=None,
    actor_scheduler_params=None,
    critic_scheduler_params=None,
    resume=None,
    load_optimizer=True,
    actor_tau=1.0,
    critic_tau=1.0,
    min_action=-1.0,
    max_action=1.0,
    **kwargs
):
    """Set up actor/critic networks, their target copies, optimizers,
    schedulers, gradient clippers and criteria, then delegate the rest
    of the configuration to the subclass ``_init`` and optionally
    restore a checkpoint.
    """
    self._device = UtilsFactory.prepare_device()

    # Networks and their target copies, all on the training device.
    self.actor = actor.to(self._device)
    self.critic = critic.to(self._device)
    self.target_actor = copy.deepcopy(actor).to(self._device)
    self.target_critic = copy.deepcopy(critic).to(self._device)

    # Optimizers; the raw param dicts are kept so subclasses can build
    # matching optimizers for additional networks.
    self.actor_optimizer = OPTIMIZERS.get_from_params(
        **actor_optimizer_params,
        params=prepare_optimizable_params(self.actor)
    )
    self.critic_optimizer = OPTIMIZERS.get_from_params(
        **critic_optimizer_params,
        params=prepare_optimizable_params(self.critic)
    )
    self.actor_optimizer_params = actor_optimizer_params
    self.critic_optimizer_params = critic_optimizer_params

    # Schedulers; params normalized to {} before being stored.
    actor_scheduler_params = actor_scheduler_params or {}
    critic_scheduler_params = critic_scheduler_params or {}
    self.actor_scheduler = SCHEDULERS.get_from_params(
        **actor_scheduler_params,
        optimizer=self.actor_optimizer
    )
    self.critic_scheduler = SCHEDULERS.get_from_params(
        **critic_scheduler_params,
        optimizer=self.critic_optimizer
    )
    self.actor_scheduler_params = actor_scheduler_params
    self.critic_scheduler_params = critic_scheduler_params

    self.n_step = n_step
    self.gamma = gamma

    # Gradient clipping; params normalized to {} before being stored.
    actor_grad_clip_params = actor_grad_clip_params or {}
    critic_grad_clip_params = critic_grad_clip_params or {}
    self.actor_grad_clip_fn = GRAD_CLIPPERS.get_from_params(
        **actor_grad_clip_params)
    self.critic_grad_clip_fn = GRAD_CLIPPERS.get_from_params(
        **critic_grad_clip_params)
    self.actor_grad_clip_params = actor_grad_clip_params
    self.critic_grad_clip_params = critic_grad_clip_params

    # Criteria may be absent; note the loss params are stored as given
    # (possibly None), unlike the scheduler/grad-clip params above.
    self.actor_criterion = CRITERIONS.get_from_params(
        **(actor_loss_params or {})
    )
    self.critic_criterion = CRITERIONS.get_from_params(
        **(critic_loss_params or {})
    )
    self.actor_loss_params = actor_loss_params
    self.critic_loss_params = critic_loss_params

    self.actor_tau = actor_tau
    self.critic_tau = critic_tau
    self.min_action = min_action
    self.max_action = max_action

    # Subclass hook for algorithm-specific setup.
    self._init(**kwargs)

    if resume is not None:
        self.load_checkpoint(resume, load_optimizer=load_optimizer)