Exemplo n.º 1
0
def get_trainer_components(
    *,
    agent,
    loss_params=None,
    optimizer_params=None,
    scheduler_params=None,
    grad_clip_params=None
):
    """Build the training components for ``agent`` from registry params.

    Each ``*_params`` argument is first passed through ``_copy_params``
    (presumably so the caller's mapping is not shared and ``None`` becomes
    an empty mapping -- confirm against the helper) and then unpacked into
    the matching registry's ``get_from_params``.

    Args:
        agent: object exposing ``parameters()``; its optimizable
            parameters are handed to the optimizer.
        loss_params: keyword arguments for ``CRITERIONS.get_from_params``.
        optimizer_params: keyword arguments for
            ``OPTIMIZERS.get_from_params``.
        scheduler_params: keyword arguments for
            ``SCHEDULERS.get_from_params``.
        grad_clip_params: keyword arguments for
            ``GRAD_CLIPPERS.get_from_params``.

    Returns:
        dict holding every created component together with the processed
        params it was built from, under the keys ``loss_params``,
        ``criterion``, ``optimizer_params``, ``optimizer``,
        ``scheduler_params``, ``scheduler``, ``grad_clip_params`` and
        ``grad_clip_fn``.
    """
    # Criterion: moved to the GPU only when one was created and CUDA exists.
    loss_params = _copy_params(loss_params)
    criterion = CRITERIONS.get_from_params(**loss_params)
    use_cuda = criterion is not None and torch.cuda.is_available()
    if use_cuda:
        criterion = criterion.cuda()

    # Optimizer over the agent's optimizable parameters.
    trainable = UtilsFactory.get_optimizable_params(agent.parameters())
    optimizer_params = _copy_params(optimizer_params)
    optimizer = OPTIMIZERS.get_from_params(
        **optimizer_params, params=trainable
    )

    # Scheduler bound to the optimizer created above.
    scheduler_params = _copy_params(scheduler_params)
    scheduler = SCHEDULERS.get_from_params(
        **scheduler_params, optimizer=optimizer
    )

    # Gradient clipping function.
    grad_clip_params = _copy_params(grad_clip_params)
    grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)

    return {
        "loss_params": loss_params,
        "criterion": criterion,
        "optimizer_params": optimizer_params,
        "optimizer": optimizer,
        "scheduler_params": scheduler_params,
        "scheduler": scheduler,
        "grad_clip_params": grad_clip_params,
        "grad_clip_fn": grad_clip_fn,
    }
Exemplo n.º 2
0
    def __init__(
        self,
        actor,
        critic,
        gamma,
        n_step,
        actor_optimizer_params,
        critic_optimizer_params,
        actor_grad_clip_params=None,
        critic_grad_clip_params=None,
        actor_loss_params=None,
        critic_loss_params=None,
        actor_scheduler_params=None,
        critic_scheduler_params=None,
        resume=None,
        load_optimizer=True,
        actor_tau=1.0,
        critic_tau=1.0,
        min_action=-1.0,
        max_action=1.0,
        **kwargs
    ):
        """Set up actor/critic networks, their targets and training tools.

        Args:
            actor: actor network; moved to the prepared device.
            critic: critic network; moved to the prepared device.
            gamma: stored as ``self.gamma``.
            n_step: stored as ``self.n_step``.
            actor_optimizer_params: keyword arguments for
                ``OPTIMIZERS.get_from_params`` (actor optimizer).
            critic_optimizer_params: same, for the critic optimizer.
            actor_grad_clip_params: keyword arguments for
                ``GRAD_CLIPPERS.get_from_params``; a falsy value is
                replaced by an empty dict.
            critic_grad_clip_params: same, for the critic.
            actor_loss_params: keyword arguments for
                ``CRITERIONS.get_from_params``; stored as passed (may be
                ``None``).
            critic_loss_params: same, for the critic.
            actor_scheduler_params: keyword arguments for
                ``SCHEDULERS.get_from_params``; a falsy value is replaced
                by an empty dict.
            critic_scheduler_params: same, for the critic.
            resume: when not ``None``, passed to ``self.load_checkpoint``
                after everything else is initialised.
            load_optimizer: forwarded to ``self.load_checkpoint``.
            actor_tau: stored as ``self.actor_tau``.
            critic_tau: stored as ``self.critic_tau``.
            min_action: stored as ``self.min_action``.
            max_action: stored as ``self.max_action``.
            **kwargs: forwarded unchanged to ``self._init``.
        """
        device = UtilsFactory.prepare_device()
        self._device = device

        # Online networks first, then deep copies used as target networks.
        self.actor = actor.to(device)
        self.critic = critic.to(device)
        self.target_actor = copy.deepcopy(actor).to(device)
        self.target_critic = copy.deepcopy(critic).to(device)

        # Optimizers.
        actor_trainable = prepare_optimizable_params(self.actor)
        self.actor_optimizer = OPTIMIZERS.get_from_params(
            **actor_optimizer_params, params=actor_trainable
        )
        critic_trainable = prepare_optimizable_params(self.critic)
        self.critic_optimizer = OPTIMIZERS.get_from_params(
            **critic_optimizer_params, params=critic_trainable
        )
        self.actor_optimizer_params = actor_optimizer_params
        self.critic_optimizer_params = critic_optimizer_params

        # Schedulers; missing params are normalised to an empty dict.
        if not actor_scheduler_params:
            actor_scheduler_params = {}
        if not critic_scheduler_params:
            critic_scheduler_params = {}
        self.actor_scheduler = SCHEDULERS.get_from_params(
            **actor_scheduler_params, optimizer=self.actor_optimizer
        )
        self.critic_scheduler = SCHEDULERS.get_from_params(
            **critic_scheduler_params, optimizer=self.critic_optimizer
        )
        self.actor_scheduler_params = actor_scheduler_params
        self.critic_scheduler_params = critic_scheduler_params

        self.n_step = n_step
        self.gamma = gamma

        # Gradient clipping; missing params are normalised to an empty dict.
        if not actor_grad_clip_params:
            actor_grad_clip_params = {}
        if not critic_grad_clip_params:
            critic_grad_clip_params = {}
        self.actor_grad_clip_fn = GRAD_CLIPPERS.get_from_params(
            **actor_grad_clip_params
        )
        self.critic_grad_clip_fn = GRAD_CLIPPERS.get_from_params(
            **critic_grad_clip_params
        )
        self.actor_grad_clip_params = actor_grad_clip_params
        self.critic_grad_clip_params = critic_grad_clip_params

        # Criterions; the raw loss params (possibly None) are what is stored.
        actor_criterion_kwargs = actor_loss_params or {}
        self.actor_criterion = CRITERIONS.get_from_params(
            **actor_criterion_kwargs
        )
        critic_criterion_kwargs = critic_loss_params or {}
        self.critic_criterion = CRITERIONS.get_from_params(
            **critic_criterion_kwargs
        )
        self.actor_loss_params = actor_loss_params
        self.critic_loss_params = critic_loss_params

        self.actor_tau = actor_tau
        self.critic_tau = critic_tau

        self.min_action = min_action
        self.max_action = max_action

        # Algorithm-specific setup receives the remaining keyword arguments.
        self._init(**kwargs)

        if resume is None:
            return
        self.load_checkpoint(resume, load_optimizer=load_optimizer)