def get_schedulers(self, model):
    if not self.is_empty:
        try:
            schedulers_out = {}
            schedulers_config = self.schedulers
            for scheduler_type, (scheduler_opt, scheduler_state) in schedulers_config.items():
                if scheduler_type == "lr_scheduler":
                    optimizer = model.optimizer
                    scheduler = instantiate_scheduler(optimizer, scheduler_opt)
                    scheduler.load_state_dict(scheduler_state)
                    schedulers_out["lr_scheduler"] = scheduler
                elif scheduler_type == "bn_scheduler":
                    scheduler = instantiate_bn_scheduler(model, scheduler_opt)
                    scheduler.load_state_dict(scheduler_state)
                    schedulers_out["bn_scheduler"] = scheduler
                else:
                    raise NotImplementedError
            return schedulers_out
        except Exception:
            log.warning("The checkpoint doesn't contain schedulers")
            return None
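# Usage sketch (hypothetical): `get_schedulers` is meant to be called on a non-empty checkpoint
# after the model's optimizer has been created, since the LR scheduler is re-instantiated around
# `model.optimizer`. The `checkpoint` and `model` names below are placeholders for whatever object
# owns this method and for the model being restored; they are not defined in the snippet above.
#
#     schedulers = checkpoint.get_schedulers(model)   # returns None if the checkpoint has no schedulers
#     if schedulers is not None:
#         model.schedulers = schedulers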
def load_optim_sched(self, model, load_state=True):
    if not self.is_empty:
        # Initialize optimizer
        optimizer_config = self.optimizer
        optimizer_cls = getattr(torch.optim, optimizer_config[0])
        optimizer_params = OmegaConf.create(self.run_config).training.optim.optimizer.params
        model.optimizer = optimizer_cls(model.parameters(), **optimizer_params)

        # Initialize & load schedulers
        schedulers_out = {}
        schedulers_config = self.schedulers
        for scheduler_type, (scheduler_opt, scheduler_state) in schedulers_config.items():
            if scheduler_type == "lr_scheduler":
                optimizer = model.optimizer
                scheduler = instantiate_scheduler(optimizer, OmegaConf.create(scheduler_opt))
                if load_state:
                    scheduler.load_state_dict(scheduler_state)
                schedulers_out["lr_scheduler"] = scheduler
            elif scheduler_type == "bn_scheduler":
                scheduler = instantiate_bn_scheduler(model, OmegaConf.create(scheduler_opt))
                if load_state:
                    scheduler.load_state_dict(scheduler_state)
                schedulers_out["bn_scheduler"] = scheduler

        model.schedulers = schedulers_out

        # Load optimizer state
        if load_state:
            model.optimizer.load_state_dict(optimizer_config[1])
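# Usage sketch (hypothetical): resuming a run versus fine-tuning from it. `checkpoint` is a
# placeholder for the object exposing `load_optim_sched`; the only difference between the two
# calls is whether the saved optimizer/scheduler state dicts are restored.
#
#     checkpoint.load_optim_sched(model)                    # resume: optimizer + scheduler state restored
#     checkpoint.load_optim_sched(model, load_state=False)  # fine-tune: same classes/options, fresh state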
def instantiate_optimizers(self, config):
    # Optimiser
    optimizer_opt = self.get_from_opt(
        config,
        ["training", "optim", "optimizer"],
        msg_err="optimizer needs to be defined within the training config",
    )
    optimizer_cls_name = optimizer_opt.get("class")
    optimizer_cls = getattr(torch.optim, optimizer_cls_name)
    optimizer_params = {}
    if hasattr(optimizer_opt, "params"):
        optimizer_params = optimizer_opt.params
    self._optimizer = optimizer_cls(self.parameters(), **optimizer_params)

    # LR Scheduler
    scheduler_opt = self.get_from_opt(config, ["training", "optim", "lr_scheduler"])
    if scheduler_opt:
        update_lr_scheduler_on = config.update_lr_scheduler_on
        if update_lr_scheduler_on:
            self._update_lr_scheduler_on = update_lr_scheduler_on
        scheduler_opt.update_scheduler_on = self._update_lr_scheduler_on
        lr_scheduler = instantiate_scheduler(self._optimizer, scheduler_opt)
        self._add_scheduler("lr_scheduler", lr_scheduler)

    # BN Scheduler
    bn_scheduler_opt = self.get_from_opt(config, ["training", "optim", "bn_scheduler"])
    if bn_scheduler_opt:
        update_bn_scheduler_on = config.update_bn_scheduler_on
        if update_bn_scheduler_on:
            self._update_bn_scheduler_on = update_bn_scheduler_on
        bn_scheduler_opt.update_scheduler_on = self._update_bn_scheduler_on
        bn_scheduler = instantiate_bn_scheduler(self, bn_scheduler_opt)
        self._add_scheduler("bn_scheduler", bn_scheduler)

    # Accumulated gradients
    self._accumulated_gradient_step = self.get_from_opt(config, ["training", "optim", "accumulated_gradient"])
    if self._accumulated_gradient_step:
        if self._accumulated_gradient_step > 1:
            self._accumulated_gradient_count = 0
        else:
            raise Exception("When set, the accumulated_gradient option should be an integer greater than 1")

    # Gradient clipping
    self._grad_clip = self.get_from_opt(config, ["training", "optim", "grad_clip"], default_value=-1)
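# Illustrative sketch of a config that would satisfy the lookups above. The schema is inferred
# from the paths passed to `get_from_opt`, not copied from the project's config files; the
# optimizer/scheduler classes, params and the "on_epoch" value are example assumptions.
#
#     from omegaconf import OmegaConf
#
#     config = OmegaConf.create({
#         "training": {
#             "optim": {
#                 "optimizer": {"class": "Adam", "params": {"lr": 0.001}},
#                 "lr_scheduler": {"class": "ExponentialLR", "params": {"gamma": 0.95}},
#                 "bn_scheduler": {"bn_momentum": 0.1},   # assumed shape of the BN scheduler options
#                 "accumulated_gradient": 4,               # must be > 1 when set
#                 "grad_clip": -1,                         # -1 disables gradient clipping
#             }
#         },
#         "update_lr_scheduler_on": "on_epoch",            # assumed value; controls when the scheduler steps
#     })
#     model.instantiate_optimizers(config)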
def instantiate_optimizers(self, config, cuda_enabled=False):
    # Optimiser
    optimizer_opt = self.get_from_opt(
        config,
        ["training", "optim", "optimizer"],
        msg_err="optimizer needs to be defined within the training config",
    )
    optimizer_cls_name = optimizer_opt.get("class")
    optimizer_cls = getattr(torch.optim, optimizer_cls_name)
    optimizer_params = {}
    if hasattr(optimizer_opt, "params"):
        optimizer_params = optimizer_opt.params
    self._optimizer = optimizer_cls(self.parameters(), **optimizer_params)

    # LR Scheduler
    scheduler_opt = self.get_from_opt(config, ["training", "optim", "lr_scheduler"])
    if scheduler_opt:
        update_lr_scheduler_on = config.get("update_lr_scheduler_on")  # Updated for OmegaConf 2.0
        if update_lr_scheduler_on:
            self._update_lr_scheduler_on = update_lr_scheduler_on
        scheduler_opt.update_scheduler_on = self._update_lr_scheduler_on
        lr_scheduler = instantiate_scheduler(self._optimizer, scheduler_opt)
        self._add_scheduler("lr_scheduler", lr_scheduler)

    # BN Scheduler
    bn_scheduler_opt = self.get_from_opt(config, ["training", "optim", "bn_scheduler"])
    if bn_scheduler_opt:
        update_bn_scheduler_on = config.get("update_bn_scheduler_on")  # Updated for OmegaConf 2.0
        if update_bn_scheduler_on:
            self._update_bn_scheduler_on = update_bn_scheduler_on
        bn_scheduler_opt.update_scheduler_on = self._update_bn_scheduler_on
        bn_scheduler = instantiate_bn_scheduler(self, bn_scheduler_opt)
        self._add_scheduler("bn_scheduler", bn_scheduler)

    # Accumulated gradients
    self._accumulated_gradient_step = self.get_from_opt(config, ["training", "optim", "accumulated_gradient"])
    if self._accumulated_gradient_step:
        if self._accumulated_gradient_step > 1:
            self._accumulated_gradient_count = 0
        else:
            raise Exception("When set, the accumulated_gradient option should be an integer greater than 1")

    # Gradient clipping
    self._grad_clip = self.get_from_opt(config, ["training", "optim", "grad_clip"], default_value=-1)

    # Gradient scaling (mixed precision)
    self._enable_mixed = self.get_from_opt(config, ["training", "enable_mixed"], default_value=False)
    self._enable_mixed = bool(self._enable_mixed)
    if self.is_mixed_precision() and not cuda_enabled:
        log.warning("Mixed precision is not supported on this device, using default precision...")
        self._enable_mixed = False
    elif self._enable_mixed and not self._supports_mixed:
        log.warning("Mixed precision is not supported on this model, using default precision...")
    elif self.is_mixed_precision():
        log.info("Model will use mixed precision")

    self._grad_scale = torch.cuda.amp.GradScaler(enabled=self.is_mixed_precision())
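# A minimal sketch of how the GradScaler configured above would typically be consumed in an
# optimization step, using standard torch.cuda.amp calls. This is not the project's own training
# loop: `compute_loss()` is a hypothetical placeholder, and only `self._optimizer`,
# `self._grad_scale`, `self._grad_clip` and `self.is_mixed_precision()` come from the code above.
#
#     with torch.cuda.amp.autocast(enabled=self.is_mixed_precision()):
#         loss = compute_loss()                         # forward pass under autocast (hypothetical helper)
#     self._grad_scale.scale(loss).backward()           # scale the loss to avoid fp16 gradient underflow
#     if self._grad_clip > 0:
#         self._grad_scale.unscale_(self._optimizer)    # unscale first so the clip threshold is in true units
#         torch.nn.utils.clip_grad_norm_(self.parameters(), self._grad_clip)
#     self._grad_scale.step(self._optimizer)            # skips the step if inf/nan gradients were detected
#     self._grad_scale.update()
#     self._optimizer.zero_grad()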