Example #1
    def get_schedulers(self, model):
        if not self.is_empty:
            try:
                schedulers_out = {}
                schedulers_config = self.schedulers
                for scheduler_type, (
                        scheduler_opt,
                        scheduler_state) in schedulers_config.items():
                    if scheduler_type == "lr_scheduler":
                        optimizer = model.optimizer
                        scheduler = instantiate_scheduler(
                            optimizer, scheduler_opt)
                        scheduler.load_state_dict(scheduler_state)
                        schedulers_out["lr_scheduler"] = scheduler
                    elif scheduler_type == "bn_scheduler":
                        scheduler = instantiate_bn_scheduler(
                            model, scheduler_opt)
                        scheduler.load_state_dict(scheduler_state)
                        schedulers_out["bn_scheduler"] = scheduler
                    else:
                        raise NotImplementedError
                return schedulers_out
            except Exception:
                log.warning("The checkpoint doesn't contain schedulers")
                return None
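The loop above assumes the checkpoint's schedulers attribute maps a scheduler type to an (options, state_dict) pair. A minimal, hypothetical reconstruction of that layout (the exact option keys such as "class"/"params" are assumptions, not taken from the snippet):

import torch

# Hypothetical illustration of the structure get_schedulers() iterates over;
# scheduler_type -> (scheduler_opt, scheduler_state).
net = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

schedulers_config = {
    "lr_scheduler": ({"class": "StepLR", "params": {"step_size": 10}},  # assumed option shape
                     lr_scheduler.state_dict()),
}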
Example #2
    def load_optim_sched(self, model, load_state=True):
        if not self.is_empty:
            # initialize optimizer
            optimizer_config = self.optimizer
            optimizer_cls = getattr(torch.optim, optimizer_config[0])
            optimizer_params = OmegaConf.create(
                self.run_config).training.optim.optimizer.params
            model.optimizer = optimizer_cls(model.parameters(),
                                            **optimizer_params)

            # initialize & load schedulers
            schedulers_out = {}
            schedulers_config = self.schedulers
            for scheduler_type, (scheduler_opt,
                                 scheduler_state) in schedulers_config.items():
                if scheduler_type == "lr_scheduler":
                    optimizer = model.optimizer
                    scheduler = instantiate_scheduler(
                        optimizer, OmegaConf.create(scheduler_opt))
                    if load_state:
                        scheduler.load_state_dict(scheduler_state)
                    schedulers_out["lr_scheduler"] = scheduler

                elif scheduler_type == "bn_scheduler":
                    scheduler = instantiate_bn_scheduler(
                        model, OmegaConf.create(scheduler_opt))
                    if load_state:
                        scheduler.load_state_dict(scheduler_state)
                    schedulers_out["bn_scheduler"] = scheduler

            # attach schedulers and load the optimizer state
            model.schedulers = schedulers_out
            if load_state:
                model.optimizer.load_state_dict(optimizer_config[1])
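load_optim_sched() implies a particular checkpoint layout: self.optimizer is a (class_name, state_dict) pair and self.run_config holds the original training options. A minimal, hypothetical reconstruction of those pieces and of the rebuild step they drive (names and values below are placeholders, not taken from the codebase):

import torch
from omegaconf import OmegaConf

# Stand-ins for what the checkpoint would contain.
net = torch.nn.Linear(4, 2)
reference_optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
optimizer_entry = ("Adam", reference_optimizer.state_dict())                      # -> self.optimizer
run_config = {"training": {"optim": {"optimizer": {"params": {"lr": 1e-3}}}}}     # -> self.run_config

# The rebuild mirrors the method body above.
optimizer_cls = getattr(torch.optim, optimizer_entry[0])
optimizer_params = OmegaConf.create(run_config).training.optim.optimizer.params
rebuilt = optimizer_cls(net.parameters(), **optimizer_params)
rebuilt.load_state_dict(optimizer_entry[1])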
Example #3
    def instantiate_optimizers(self, config):
        # Optimiser
        optimizer_opt = self.get_from_opt(
            config,
            ["training", "optim", "optimizer"],
            msg_err="optimizer needs to be defined within the training config",
        )
        optimizer_cls_name = optimizer_opt.get("class")
        optimizer_cls = getattr(torch.optim, optimizer_cls_name)
        optimizer_params = {}
        if hasattr(optimizer_opt, "params"):
            optimizer_params = optimizer_opt.params
        self._optimizer = optimizer_cls(self.parameters(), **optimizer_params)

        # LR Scheduler
        scheduler_opt = self.get_from_opt(
            config, ["training", "optim", "lr_scheduler"])
        if scheduler_opt:
            update_lr_scheduler_on = config.update_lr_scheduler_on
            if update_lr_scheduler_on:
                self._update_lr_scheduler_on = update_lr_scheduler_on
            scheduler_opt.update_scheduler_on = self._update_lr_scheduler_on
            lr_scheduler = instantiate_scheduler(self._optimizer,
                                                 scheduler_opt)
            self._add_scheduler("lr_scheduler", lr_scheduler)

        # BN Scheduler
        bn_scheduler_opt = self.get_from_opt(
            config, ["training", "optim", "bn_scheduler"])
        if bn_scheduler_opt:
            update_bn_scheduler_on = config.update_bn_scheduler_on
            if update_bn_scheduler_on:
                self._update_bn_scheduler_on = update_bn_scheduler_on
            bn_scheduler_opt.update_scheduler_on = self._update_bn_scheduler_on
            bn_scheduler = instantiate_bn_scheduler(self, bn_scheduler_opt)
            self._add_scheduler("bn_scheduler", bn_scheduler)

        # Accumulated gradients
        self._accumulated_gradient_step = self.get_from_opt(
            config, ["training", "optim", "accumulated_gradient"])
        if self._accumulated_gradient_step:
            if self._accumulated_gradient_step > 1:
                self._accumulated_gradient_count = 0
            else:
                raise Exception(
                    "When set, accumulated_gradient option should be an integer greater than 1"
                )

        # Gradient clipping
        self._grad_clip = self.get_from_opt(config,
                                            ["training", "optim", "grad_clip"],
                                            default_value=-1)
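instantiate_optimizers() only reads a handful of keys from the config. A hedged sketch of a config fragment containing them follows; the class names, params and update_*_scheduler_on values are assumptions, and only the key paths are taken from the code above.

from omegaconf import OmegaConf

# Hypothetical config fragment with the keys looked up by instantiate_optimizers().
config = OmegaConf.create({
    "update_lr_scheduler_on": "on_epoch",   # assumed value
    "update_bn_scheduler_on": "on_epoch",   # assumed value
    "training": {
        "optim": {
            "optimizer": {"class": "Adam", "params": {"lr": 1e-3}},
            "lr_scheduler": {"class": "ExponentialLR", "params": {"gamma": 0.9}},
            "bn_scheduler": {"params": {"bn_momentum": 0.1}},   # assumed shape
            "accumulated_gradient": 4,   # must be an integer > 1 when set
            "grad_clip": -1,             # -1 (the default above) disables clipping
        },
    },
})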
Example #4
    def instantiate_optimizers(self, config, cuda_enabled=False):
        # Optimiser
        optimizer_opt = self.get_from_opt(
            config,
            ["training", "optim", "optimizer"],
            msg_err="optimizer needs to be defined within the training config",
        )
        optimizer_cls_name = optimizer_opt.get("class")
        optimizer_cls = getattr(torch.optim, optimizer_cls_name)
        optimizer_params = {}
        if hasattr(optimizer_opt, "params"):
            optimizer_params = optimizer_opt.params
        self._optimizer = optimizer_cls(self.parameters(), **optimizer_params)

        # LR Scheduler
        scheduler_opt = self.get_from_opt(
            config, ["training", "optim", "lr_scheduler"])
        if scheduler_opt:
            update_lr_scheduler_on = config.get(
                'update_lr_scheduler_on')  # Update to OmegaConf 2.0
            if update_lr_scheduler_on:
                self._update_lr_scheduler_on = update_lr_scheduler_on
            scheduler_opt.update_scheduler_on = self._update_lr_scheduler_on
            lr_scheduler = instantiate_scheduler(self._optimizer,
                                                 scheduler_opt)
            self._add_scheduler("lr_scheduler", lr_scheduler)

        # BN Scheduler
        bn_scheduler_opt = self.get_from_opt(
            config, ["training", "optim", "bn_scheduler"])
        if bn_scheduler_opt:
            update_bn_scheduler_on = config.get(
                'update_bn_scheduler_on')  # update to OmegaConf 2.0
            if update_bn_scheduler_on:
                self._update_bn_scheduler_on = update_bn_scheduler_on
            bn_scheduler_opt.update_scheduler_on = self._update_bn_scheduler_on
            bn_scheduler = instantiate_bn_scheduler(self, bn_scheduler_opt)
            self._add_scheduler("bn_scheduler", bn_scheduler)

        # Accumulated gradients
        self._accumulated_gradient_step = self.get_from_opt(
            config, ["training", "optim", "accumulated_gradient"])
        if self._accumulated_gradient_step:
            if self._accumulated_gradient_step > 1:
                self._accumulated_gradient_count = 0
            else:
                raise Exception(
                    "When set, accumulated_gradient option should be an integer greater than 1"
                )

        # Gradient clipping
        self._grad_clip = self.get_from_opt(config,
                                            ["training", "optim", "grad_clip"],
                                            default_value=-1)

        # Gradient Scaling
        self._enable_mixed = self.get_from_opt(config,
                                               ["training", "enable_mixed"],
                                               default_value=False)
        self._enable_mixed = bool(self._enable_mixed)
        if self.is_mixed_precision() and not cuda_enabled:
            log.warning(
                "Mixed precision is not supported on this device, using default precision..."
            )
            self._enable_mixed = False
        elif self._enable_mixed and not self._supports_mixed:
            log.warning(
                "Mixed precision is not supported on this model, using default precision..."
            )
        elif self.is_mixed_precision():
            log.info("Model will use mixed precision")

        self._grad_scale = torch.cuda.amp.GradScaler(
            enabled=self.is_mixed_precision())
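Example #4 ends by constructing a GradScaler whose enabled flag mirrors the mixed-precision decision. One common way such a scaler is consumed in an optimize step, together with the grad_clip value read above, is sketched below; this is a generic PyTorch AMP pattern, not code from this repository, and the function and argument names are placeholders.

import torch

def optimize_step(model, loss, optimizer, scaler, grad_clip=-1):
    # Backward through a (possibly) scaled loss; with a disabled scaler this
    # degenerates to a plain backward pass.
    scaler.scale(loss).backward()
    if grad_clip > 0:
        scaler.unscale_(optimizer)  # gradients must be unscaled before clipping
        torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
    scaler.step(optimizer)          # skips the update if inf/NaN gradients were found
    scaler.update()
    optimizer.zero_grad()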