Example #1
def get_trainer_components(
    *,
    agent,
    loss_params=None,
    optimizer_params=None,
    scheduler_params=None,
    grad_clip_params=None
):
    # criterion
    loss_params = _copy_params(loss_params)
    criterion = CRITERIONS.get_from_params(**loss_params)
    if criterion is not None and torch.cuda.is_available():
        criterion = criterion.cuda()

    # optimizer
    agent_params = UtilsFactory.get_optimizable_params(
        agent.parameters())
    optimizer_params = _copy_params(optimizer_params)
    optimizer = OPTIMIZERS.get_from_params(
        **optimizer_params,
        params=agent_params
    )

    # scheduler
    scheduler_params = _copy_params(scheduler_params)
    scheduler = SCHEDULERS.get_from_params(
        **scheduler_params,
        optimizer=optimizer
    )

    # grad clipping
    grad_clip_params = _copy_params(grad_clip_params)
    grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)

    result = {
        "loss_params": loss_params,
        "criterion": criterion,
        "optimizer_params": optimizer_params,
        "optimizer": optimizer,
        "scheduler_params": scheduler_params,
        "scheduler": scheduler,
        "grad_clip_params": grad_clip_params,
        "grad_clip_fn": grad_clip_fn
    }

    return result
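
A hedged usage sketch, assuming Catalyst-style config dicts in which each registry resolves its class from a key named after the registry ("criterion", "optimizer", "scheduler") and that agent is a torch.nn.Module; all concrete values below are illustrative, not taken from the source:

# Illustrative call (assumption: the registries resolve classes from keys
# named after the registry, as in typical Catalyst configs).
components = get_trainer_components(
    agent=agent,
    loss_params={"criterion": "MSELoss"},
    optimizer_params={"optimizer": "Adam", "lr": 1e-3},
    scheduler_params={"scheduler": "StepLR", "step_size": 10, "gamma": 0.5},
    grad_clip_params={},  # left empty, so grad_clip_fn is expected to be None
)
criterion = components["criterion"]
optimizer = components["optimizer"]
scheduler = components["scheduler"]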
Example #2
    def __init__(self,
                 grad_clip_params: Dict = None,
                 accumulation_steps: int = 1,
                 optimizer_key: str = None,
                 loss_key: str = None,
                 prefix: str = None):
        """
        @TODO: docs
        """

        grad_clip_params = grad_clip_params or {}
        self.grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)

        self.accumulation_steps = accumulation_steps
        self.optimizer_key = optimizer_key
        self.loss_key = loss_key
        self.prefix = prefix
        self._optimizer_wd = 0
        self._accumulation_counter = 0
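
The accumulation_steps / _accumulation_counter pair points to the usual gradient-accumulation pattern; a minimal standalone sketch of that pattern follows (names and structure are illustrative, not taken from the callback itself):

def accumulated_optimization(model, optimizer, criterion, batches,
                             accumulation_steps=1):
    # Plain gradient accumulation: scale the loss, call backward() for
    # every batch, and make one optimizer step per `accumulation_steps`
    # batches. Illustrative sketch only.
    optimizer.zero_grad()
    for counter, (x, y) in enumerate(batches, start=1):
        loss = criterion(model(x), y) / accumulation_steps
        loss.backward()
        if counter % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()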
Example #3
    def __init__(
        self,
        grad_clip_params: Dict = None,
        fp16_grad_scale: float = 128.0,
        accumulation_steps: int = 1,
        optimizer_key: str = None,
        loss_key: str = None
    ):
        """
        @TODO: docs
        """

        grad_clip_params = grad_clip_params or {}
        self.grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)

        self.fp16 = False
        self.fp16_grad_scale = fp16_grad_scale
        self.accumulation_steps = accumulation_steps
        self.optimizer_key = optimizer_key
        self.loss_key = loss_key
        self._optimizer_wd = 0
        self._accumulation_counter = 0
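
The fp16_grad_scale default of 128.0 hints at static loss scaling; below is a minimal standalone sketch of that technique, under the assumption that this is how the flag is used (the helper name is hypothetical, not from the source):

def fp16_scaled_backward(loss, optimizer, grad_scale=128.0):
    # Static loss scaling: scale the loss before backward() so small fp16
    # gradients do not underflow, then unscale the gradients in place
    # before the optimizer step. Illustrative sketch, not library code.
    (loss * grad_scale).backward()
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is not None:
                p.grad.data.div_(grad_scale)
    optimizer.step()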
Example #4
    def __init__(
        self,
        actor,
        critic,
        gamma,
        n_step,
        actor_optimizer_params,
        critic_optimizer_params,
        actor_grad_clip_params=None,
        critic_grad_clip_params=None,
        actor_loss_params=None,
        critic_loss_params=None,
        actor_scheduler_params=None,
        critic_scheduler_params=None,
        resume=None,
        load_optimizer=True,
        actor_tau=1.0,
        critic_tau=1.0,
        min_action=-1.0,
        max_action=1.0,
        **kwargs
    ):
        self._device = UtilsFactory.prepare_device()

        # online networks
        self.actor = actor.to(self._device)
        self.critic = critic.to(self._device)

        # target networks (deep copies of the online networks)
        self.target_actor = copy.deepcopy(actor).to(self._device)
        self.target_critic = copy.deepcopy(critic).to(self._device)

        # optimizers
        self.actor_optimizer = OPTIMIZERS.get_from_params(
            **actor_optimizer_params,
            params=prepare_optimizable_params(self.actor)
        )
        self.critic_optimizer = OPTIMIZERS.get_from_params(
            **critic_optimizer_params,
            params=prepare_optimizable_params(self.critic)
        )
        self.actor_optimizer_params = actor_optimizer_params
        self.critic_optimizer_params = critic_optimizer_params

        # schedulers
        actor_scheduler_params = actor_scheduler_params or {}
        critic_scheduler_params = critic_scheduler_params or {}

        self.actor_scheduler = SCHEDULERS.get_from_params(
            **actor_scheduler_params,
            optimizer=self.actor_optimizer
        )
        self.critic_scheduler = SCHEDULERS.get_from_params(
            **critic_scheduler_params,
            optimizer=self.critic_optimizer
        )

        self.actor_scheduler_params = actor_scheduler_params
        self.critic_scheduler_params = critic_scheduler_params

        self.n_step = n_step
        self.gamma = gamma

        # gradient clipping
        actor_grad_clip_params = actor_grad_clip_params or {}
        critic_grad_clip_params = critic_grad_clip_params or {}

        self.actor_grad_clip_fn = \
            GRAD_CLIPPERS.get_from_params(**actor_grad_clip_params)
        self.critic_grad_clip_fn = \
            GRAD_CLIPPERS.get_from_params(**critic_grad_clip_params)

        self.actor_grad_clip_params = actor_grad_clip_params
        self.critic_grad_clip_params = critic_grad_clip_params

        # criterions
        self.actor_criterion = CRITERIONS.get_from_params(
            **(actor_loss_params or {})
        )
        self.critic_criterion = CRITERIONS.get_from_params(
            **(critic_loss_params or {})
        )

        self.actor_loss_params = actor_loss_params
        self.critic_loss_params = critic_loss_params

        self.actor_tau = actor_tau
        self.critic_tau = critic_tau

        self.min_action = min_action
        self.max_action = max_action

        self._init(**kwargs)

        if resume is not None:
            self.load_checkpoint(resume, load_optimizer=load_optimizer)
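
The actor_tau / critic_tau fields suggest Polyak-style soft updates of the target networks; a minimal standalone sketch of that update rule follows (an assumption about how tau is used, not code from the source):

import torch

def soft_update(target_net, source_net, tau=1.0):
    # Polyak averaging of target parameters towards the online parameters;
    # with tau=1.0 (the default above) this reduces to a hard copy.
    with torch.no_grad():
        for t_param, s_param in zip(target_net.parameters(),
                                    source_net.parameters()):
            t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)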