Esempio n. 1
0
    def _init(
        self,
        critics,
        action_noise_std=0.2,
        action_noise_clip=0.5,
        values_range=(-10., 10.),
        critic_distribution=None,
        **kwargs
    ):
        """Register the extra critics and select the critic loss function.

        Args:
            critics: iterable of additional critic networks. Each one is
                moved to ``self._device`` and receives its own optimizer,
                scheduler and deep-copied target network.
            action_noise_std: stored as ``self.action_noise_std``.
            action_noise_clip: stored as ``self.action_noise_clip``.
            values_range: ``(v_min, v_max)`` pair; only read when
                ``critic_distribution == "categorical"``.
            critic_distribution: ``"quantile"`` or ``"categorical"`` switch
                the loss to the matching method; any other value keeps
                ``self._base_loss``.
            **kwargs: forwarded unchanged to the parent ``_init``.
        """
        super()._init(**kwargs)

        self.n_atoms = self.critic.out_features
        # Default loss; overridden below for the distributional variants.
        self._loss_fn = self._base_loss

        self.action_noise_std = action_noise_std
        self.action_noise_clip = action_noise_clip

        # Build the per-critic training machinery stage by stage, reusing
        # the optimizer/scheduler settings already stored on the instance.
        extra_critics = [net.to(self._device) for net in critics]
        extra_optimizers = [
            OPTIMIZERS.get_from_params(
                **self.critic_optimizer_params,
                params=prepare_optimizable_params(net)
            )
            for net in extra_critics
        ]
        extra_schedulers = [
            SCHEDULERS.get_from_params(
                **self.critic_scheduler_params,
                optimizer=optimizer
            )
            for optimizer in extra_optimizers
        ]
        extra_targets = [
            copy.deepcopy(net).to(self._device) for net in extra_critics
        ]

        # The base critic (and its companions) always sits at index 0.
        self.critics = [self.critic, *extra_critics]
        self.critics_optimizer = [self.critic_optimizer, *extra_optimizers]
        self.critics_scheduler = [self.critic_scheduler, *extra_schedulers]
        self.target_critics = [self.target_critic, *extra_targets]

        if critic_distribution == "quantile":
            # Midpoints of ``n_atoms`` equal-width bins covering [0, 1].
            half_bin = 1 / (2 * self.n_atoms)
            quantile_levels = torch.linspace(
                start=half_bin,
                end=1 - half_bin,
                steps=self.n_atoms
            )
            self.tau = self._to_tensor(quantile_levels)
            self._loss_fn = self._quantile_loss
        elif critic_distribution == "categorical":
            lower, upper = values_range
            self.v_min = lower
            self.v_max = upper
            # Distance between neighbouring points of the evenly spaced grid.
            self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)
            support = torch.linspace(
                start=self.v_min,
                end=self.v_max,
                steps=self.n_atoms
            )
            self.z = self._to_tensor(support)
            self._loss_fn = self._categorical_loss
Esempio n. 2
0
def get_trainer_components(
    *,
    agent,
    loss_params=None,
    optimizer_params=None,
    scheduler_params=None,
    grad_clip_params=None
):
    """Build the criterion, optimizer, scheduler and grad clipper for *agent*.

    Every ``*_params`` argument is first passed through ``_copy_params`` and
    the resulting value is what gets unpacked into the matching registry
    (``CRITERIONS``, ``OPTIMIZERS``, ``SCHEDULERS``, ``GRAD_CLIPPERS``).

    Args:
        agent: object exposing ``parameters()``; its optimizable parameters
            are handed to the optimizer.
        loss_params: keyword settings for the criterion.
        optimizer_params: keyword settings for the optimizer.
        scheduler_params: keyword settings for the scheduler.
        grad_clip_params: keyword settings for the gradient clipper.

    Returns:
        dict holding each built component together with the (copied)
        parameters used to build it.
    """
    # --- criterion ---------------------------------------------------------
    loss_params = _copy_params(loss_params)
    criterion = CRITERIONS.get_from_params(**loss_params)
    if criterion is not None:
        # Only move the criterion when a CUDA device is actually present.
        if torch.cuda.is_available():
            criterion = criterion.cuda()

    # --- optimizer ---------------------------------------------------------
    trainable = UtilsFactory.get_optimizable_params(agent.parameters())
    optimizer_params = _copy_params(optimizer_params)
    optimizer = OPTIMIZERS.get_from_params(
        **optimizer_params, params=trainable
    )

    # --- scheduler ---------------------------------------------------------
    scheduler_params = _copy_params(scheduler_params)
    scheduler = SCHEDULERS.get_from_params(
        **scheduler_params, optimizer=optimizer
    )

    # --- gradient clipping -------------------------------------------------
    grad_clip_params = _copy_params(grad_clip_params)
    grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)

    return dict(
        loss_params=loss_params,
        criterion=criterion,
        optimizer_params=optimizer_params,
        optimizer=optimizer,
        scheduler_params=scheduler_params,
        scheduler=scheduler,
        grad_clip_params=grad_clip_params,
        grad_clip_fn=grad_clip_fn,
    )
Esempio n. 3
0
    def __init__(
        self,
        actor,
        critic,
        gamma,
        n_step,
        actor_optimizer_params,
        critic_optimizer_params,
        actor_grad_clip_params=None,
        critic_grad_clip_params=None,
        actor_loss_params=None,
        critic_loss_params=None,
        actor_scheduler_params=None,
        critic_scheduler_params=None,
        resume=None,
        load_optimizer=True,
        actor_tau=1.0,
        critic_tau=1.0,
        min_action=-1.0,
        max_action=1.0,
        **kwargs
    ):
        """Wire up actor/critic networks and their training components.

        Args:
            actor: actor network; moved to the prepared device and
                deep-copied into ``self.target_actor``.
            critic: critic network; moved to the prepared device and
                deep-copied into ``self.target_critic``.
            gamma: stored as ``self.gamma``.
            n_step: stored as ``self.n_step``.
            actor_optimizer_params: keyword settings for the actor optimizer.
            critic_optimizer_params: keyword settings for the critic
                optimizer.
            actor_grad_clip_params: keyword settings for the actor gradient
                clipper; a falsy value is replaced by ``{}``.
            critic_grad_clip_params: same, for the critic.
            actor_loss_params: keyword settings for the actor criterion; the
                value is stored as passed (possibly ``None``).
            critic_loss_params: same, for the critic.
            actor_scheduler_params: keyword settings for the actor scheduler;
                a falsy value is replaced by ``{}``.
            critic_scheduler_params: same, for the critic.
            resume: when not ``None``, passed to ``self.load_checkpoint``
                after all other setup has finished.
            load_optimizer: forwarded to ``self.load_checkpoint``.
            actor_tau: stored as ``self.actor_tau`` (presumably the
                target-network update coefficient -- confirm in the update
                code).
            critic_tau: stored as ``self.critic_tau`` (same caveat).
            min_action: stored as ``self.min_action``.
            max_action: stored as ``self.max_action``.
            **kwargs: forwarded to ``self._init``.
        """
        self._device = UtilsFactory.prepare_device()

        # Networks first: the originals are moved to the device, then the
        # target copies are taken from them.
        self.actor = actor.to(self._device)
        self.critic = critic.to(self._device)
        self.target_actor = copy.deepcopy(actor).to(self._device)
        self.target_critic = copy.deepcopy(critic).to(self._device)

        # Optimizers.
        self.actor_optimizer = OPTIMIZERS.get_from_params(
            **actor_optimizer_params,
            params=prepare_optimizable_params(self.actor)
        )
        self.critic_optimizer = OPTIMIZERS.get_from_params(
            **critic_optimizer_params,
            params=prepare_optimizable_params(self.critic)
        )
        self.actor_optimizer_params = actor_optimizer_params
        self.critic_optimizer_params = critic_optimizer_params

        # Schedulers (missing settings fall back to an empty dict).
        actor_scheduler_params = actor_scheduler_params or {}
        critic_scheduler_params = critic_scheduler_params or {}
        self.actor_scheduler = SCHEDULERS.get_from_params(
            **actor_scheduler_params, optimizer=self.actor_optimizer
        )
        self.critic_scheduler = SCHEDULERS.get_from_params(
            **critic_scheduler_params, optimizer=self.critic_optimizer
        )
        self.actor_scheduler_params = actor_scheduler_params
        self.critic_scheduler_params = critic_scheduler_params

        self.n_step = n_step
        self.gamma = gamma

        # Gradient clipping (missing settings fall back to an empty dict).
        actor_grad_clip_params = actor_grad_clip_params or {}
        critic_grad_clip_params = critic_grad_clip_params or {}
        self.actor_grad_clip_fn = GRAD_CLIPPERS.get_from_params(
            **actor_grad_clip_params
        )
        self.critic_grad_clip_fn = GRAD_CLIPPERS.get_from_params(
            **critic_grad_clip_params
        )
        self.actor_grad_clip_params = actor_grad_clip_params
        self.critic_grad_clip_params = critic_grad_clip_params

        # Criterions; the stored loss params keep whatever the caller passed.
        self.actor_criterion = CRITERIONS.get_from_params(
            **(actor_loss_params or {})
        )
        self.critic_criterion = CRITERIONS.get_from_params(
            **(critic_loss_params or {})
        )
        self.actor_loss_params = actor_loss_params
        self.critic_loss_params = critic_loss_params

        self.actor_tau = actor_tau
        self.critic_tau = critic_tau

        self.min_action = min_action
        self.max_action = max_action

        # Algorithm-specific setup hook; receives the leftover keyword args.
        self._init(**kwargs)

        if resume is None:
            return
        self.load_checkpoint(resume, load_optimizer=load_optimizer)