Example #1
    def get_optimizer(self, stage: str, model) -> _Optimizer:
        # use half-precision optimizer handling when the model is fp16-wrapped
        fp16 = isinstance(model, Fp16Wrap)
        optimizer_params = self.stages_config[stage].get(
            "optimizer_params", {})
        optimizer = Registry.get_optimizer(model,
                                           **optimizer_params,
                                           fp16=fp16)
        return optimizer
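A minimal stand-alone sketch of what such a config-driven lookup resolves to, using plain PyTorch; the stage name, the config layout beyond optimizer_params, and the Adam choice are illustrative assumptions, not taken from the source:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in model for illustration
    # hypothetical stage config mirroring the lookup in get_optimizer above
    stages_config = {"train": {"optimizer_params": {"lr": 1e-3}}}
    optimizer_params = stages_config["train"].get("optimizer_params", {})
    # rough plain-torch equivalent of Registry.get_optimizer(model, ...)
    optimizer = torch.optim.Adam(model.parameters(), **optimizer_params)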
Example #2
    def _init(self,
              critics,
              action_noise_std=0.2,
              action_noise_clip=0.5,
              values_range=(-10., 10.),
              critic_distribution=None,
              **kwargs):
        super()._init(**kwargs)
        # local import: a hack to prevent cyclic dependencies
        from catalyst.contrib.registry import Registry

        # number of atoms in the distributional critic's output head
        self.n_atoms = self.critic.out_features
        self._loss_fn = self._base_loss

        self.action_noise_std = action_noise_std
        self.action_noise_clip = action_noise_clip

        # move the extra critics to the device and build an optimizer and
        # scheduler for each of them
        critics = [x.to(self._device) for x in critics]
        critics_optimizer = [
            Registry.get_optimizer(x, **self.critic_optimizer_params)
            for x in critics
        ]
        critics_scheduler = [
            Registry.get_scheduler(x, **self.critic_scheduler_params)
            for x in critics_optimizer
        ]
        # frozen copies that serve as slowly-updated targets
        target_critics = [copy.deepcopy(x).to(self._device) for x in critics]

        self.critics = [self.critic] + critics
        self.critics_optimizer = [self.critic_optimizer] + critics_optimizer
        self.critics_scheduler = [self.critic_scheduler] + critics_scheduler
        self.target_critics = [self.target_critic] + target_critics

        if critic_distribution == "quantile":
            # evenly spaced quantile midpoints: tau_i = (2i + 1) / (2 * n_atoms)
            tau_min = 1 / (2 * self.n_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(start=tau_min,
                                 end=tau_max,
                                 steps=self.n_atoms)
            self.tau = self._to_tensor(tau)
            self._loss_fn = self._quantile_loss
        elif critic_distribution == "categorical":
            # fixed support of n_atoms values spaced delta_z apart (C51-style)
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)
            z = torch.linspace(start=self.v_min,
                               end=self.v_max,
                               steps=self.n_atoms)
            self.z = self._to_tensor(z)
            self._loss_fn = self._categorical_loss
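To make the two branches concrete, here is a standalone computation of the quantile midpoints and the categorical support for a small, arbitrary atom count:

    import torch

    n_atoms = 5  # illustrative atom count
    # quantile branch: midpoints tau_i = (2i + 1) / (2 * n_atoms)
    tau_min = 1 / (2 * n_atoms)
    tau = torch.linspace(start=tau_min, end=1 - tau_min, steps=n_atoms)
    print(tau)  # tensor([0.1000, 0.3000, 0.5000, 0.7000, 0.9000])

    # categorical branch: fixed support over values_range, spaced delta_z apart
    v_min, v_max = -10., 10.
    delta_z = (v_max - v_min) / (n_atoms - 1)
    z = torch.linspace(start=v_min, end=v_max, steps=n_atoms)
    print(delta_z, z)  # 5.0 tensor([-10., -5., 0., 5., 10.])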
Example #3
    def prepare_model_stuff(*,
                            model,
                            criterion_params=None,
                            optimizer_params=None,
                            scheduler_params=None):
        # enable half-precision optimizer handling for fp16-wrapped models
        fp16 = isinstance(model, Fp16Wrap)

        # fall back to empty dicts so each factory applies its defaults
        criterion_params = criterion_params or {}
        criterion = Registry.get_criterion(**criterion_params)

        optimizer_params = optimizer_params or {}
        optimizer = Registry.get_optimizer(model,
                                           **optimizer_params,
                                           fp16=fp16)

        scheduler_params = scheduler_params or {}
        scheduler = Registry.get_scheduler(optimizer, **scheduler_params)

        return criterion, optimizer, scheduler
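For reference, a hypothetical call to prepare_model_stuff; the key names inside each params dict are assumptions about what the Registry factories expect, not taken from the source:

    # hypothetical usage; the param keys are illustrative assumptions
    criterion, optimizer, scheduler = prepare_model_stuff(
        model=model,
        criterion_params={"criterion": "CrossEntropyLoss"},
        optimizer_params={"optimizer": "Adam", "lr": 1e-3},
        scheduler_params={"scheduler": "StepLR", "step_size": 10},
    )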
Example #4
    def _init(
        self,
        critics,
        reward_scale=1.0,
        values_range=(-10., 10.),
        critic_distribution=None,
        **kwargs
    ):
        """
        Parameters
        ----------
        reward_scale: float,
            THE MOST IMPORTANT HYPERPARAMETER which controls the ratio
            between maximizing rewards and acting as randomly as possible
        use_regularization: bool,
            whether to use l2 regularization on policy network outputs,
            regularization can not be used with RealNVPActor
        mu_and_sigma_reg: float,
            coefficient for l2 regularization on mu and log_sigma
        policy_grad_estimator: str,
            "reinforce": may be used with arbitrary explicit policy
            "reparametrization_trick": may be used with reparametrizable
            policy, e.g. Gaussian, normalizing flow (Real NVP).
        """
        super()._init(**kwargs)
        # local import: a hack to prevent cyclic dependencies
        from catalyst.contrib.registry import Registry

        self.n_atoms = self.critic.out_features
        self._loss_fn = self._base_loss

        self.reward_scale = reward_scale
        # @TODO: policy regularization

        critics = [x.to(self._device) for x in critics]
        critics_optimizer = [
            Registry.get_optimizer(x, **self.critic_optimizer_params)
            for x in critics
        ]
        critics_scheduler = [
            Registry.get_scheduler(x, **self.critic_scheduler_params)
            for x in critics_optimizer
        ]
        target_critics = [copy.deepcopy(x).to(self._device) for x in critics]

        self.critics = [self.critic] + critics
        self.critics_optimizer = [self.critic_optimizer] + critics_optimizer
        self.critics_scheduler = [self.critic_scheduler] + critics_scheduler
        self.target_critics = [self.target_critic] + target_critics

        if critic_distribution == "quantile":
            tau_min = 1 / (2 * self.n_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(
                start=tau_min, end=tau_max, steps=self.n_atoms
            )
            self.tau = self._to_tensor(tau)
            self._loss_fn = self._quantile_loss
        elif critic_distribution == "categorical":
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self.n_atoms
            )
            self.z = self._to_tensor(z)
            self._loss_fn = self._categorical_loss
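As a numeric illustration of the categorical branch shared by Examples #2 and #4: a scalar target value is clipped to [v_min, v_max] and mapped to a fractional atom index, between whose neighbouring atoms the probability mass is split. This is the standard C51 projection step; the code below is an independent sketch, not taken from _categorical_loss:

    import torch

    v_min, v_max, n_atoms = -10., 10., 51       # illustrative C51-style settings
    delta_z = (v_max - v_min) / (n_atoms - 1)   # atom spacing, here 0.4
    v = torch.tensor([3.7]).clamp(v_min, v_max)
    b = (v - v_min) / delta_z                   # fractional atom index: 34.25
    lower, upper = b.floor(), b.ceil()          # mass split between atoms 34 and 35
    print(b, lower, upper)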
Example #5
    def __init__(self,
                 actor,
                 critic,
                 gamma,
                 n_step,
                 actor_optimizer_params,
                 critic_optimizer_params,
                 actor_grad_clip_params=None,
                 critic_grad_clip_params=None,
                 actor_loss_params=None,
                 critic_loss_params=None,
                 actor_scheduler_params=None,
                 critic_scheduler_params=None,
                 resume=None,
                 load_optimizer=True,
                 actor_tau=1.0,
                 critic_tau=1.0,
                 min_action=-1.0,
                 max_action=1.0,
                 **kwargs):
        # local import: a hack to prevent cyclic dependencies
        from catalyst.contrib.registry import Registry

        # pick the computation device (GPU when available)
        self._device = UtilsFactory.prepare_device()

        self.actor = actor.to(self._device)
        self.critic = critic.to(self._device)

        # target networks start as exact copies of the main networks
        self.target_actor = copy.deepcopy(actor).to(self._device)
        self.target_critic = copy.deepcopy(critic).to(self._device)

        self.actor_optimizer = Registry.get_optimizer(self.actor,
                                                      **actor_optimizer_params)
        self.critic_optimizer = Registry.get_optimizer(
            self.critic, **critic_optimizer_params)

        self.actor_optimizer_params = actor_optimizer_params
        self.critic_optimizer_params = critic_optimizer_params

        actor_scheduler_params = actor_scheduler_params or {}
        critic_scheduler_params = critic_scheduler_params or {}

        self.actor_scheduler = Registry.get_scheduler(self.actor_optimizer,
                                                      **actor_scheduler_params)
        self.critic_scheduler = Registry.get_scheduler(
            self.critic_optimizer, **critic_scheduler_params)

        self.actor_scheduler_params = actor_scheduler_params
        self.critic_scheduler_params = critic_scheduler_params

        # n-step return horizon and discount factor
        self.n_step = n_step
        self.gamma = gamma

        actor_grad_clip_params = actor_grad_clip_params or {}
        critic_grad_clip_params = critic_grad_clip_params or {}

        self.actor_grad_clip_fn = Registry.get_grad_clip_fn(
            **actor_grad_clip_params)
        self.critic_grad_clip_fn = Registry.get_grad_clip_fn(
            **critic_grad_clip_params)

        self.actor_grad_clip_params = actor_grad_clip_params
        self.critic_grad_clip_params = critic_grad_clip_params

        self.actor_criterion = Registry.get_criterion(
            **(actor_loss_params or {}))
        self.critic_criterion = Registry.get_criterion(
            **(critic_loss_params or {}))

        self.actor_loss_params = actor_loss_params
        self.critic_loss_params = critic_loss_params

        # soft-update (Polyak averaging) coefficients for the target networks
        self.actor_tau = actor_tau
        self.critic_tau = critic_tau

        self.min_action = min_action
        self.max_action = max_action

        self._init(**kwargs)

        if resume is not None:
            # restore weights (and, optionally, optimizer state) from checkpoint
            self.load_checkpoint(resume, load_optimizer=load_optimizer)
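The actor_tau and critic_tau arguments suggest soft target updates, although the update itself is outside this snippet. A minimal sketch of Polyak averaging under that assumption; soft_update is a hypothetical helper, not part of the source:

    import copy
    import torch.nn as nn

    def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
        # target <- tau * source + (1 - tau) * target, parameter by parameter
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.data.copy_(tau * s_param.data + (1 - tau) * t_param.data)

    actor = nn.Linear(8, 2)                      # stand-in for the real actor
    target_actor = copy.deepcopy(actor)
    soft_update(target_actor, actor, tau=0.005)  # tau=1.0 would hard-copy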