def test_step(self, *args, **kwargs) -> dict:
    self.eval()
    outputs = self._test_step(*args, **kwargs)
    assert is_dict(outputs), "Output of _test_step should be a dict"
    # Hopefully avoid any memory leak on gpu
    outputs = detach_dict(outputs)
    outputs = to_device(outputs, 'cpu')
    return outputs
def training_step(self, batch, *args, **kwargs) -> dict:
    self.train()
    outputs = self._training_step(batch, *args, **kwargs)
    assert is_dict(outputs), "Output of _training_step should be a dict"
    # Hopefully avoid any memory leak on gpu
    outputs = detach_dict(outputs)
    outputs = to_device(outputs, 'cpu')
    self._train_step += 1
    return outputs
def _check_optimizer_accum(self):
    """
    Check if the helper functions for `loss.backward()`, `optim.step()`
    and `optim.zero_grad()` are used instead of the pytorch originals
    when gradient accumulation or clipping is used.
    """
    # If not already, convert into a dict for convenience
    optim_dict = (self.optimizer if is_dict(self.optimizer)
                  else {'optimizer': self.optimizer})

    clip_val = self.args.grads_norm_clip.max_norm
    accum_batches = self.args.accum_batches
    if clip_val > 0 or accum_batches > 1:
        # -- Warning messages stuff
        warn_pattern = ("{1} has been called directly while using {0}, "
                        "which will prevent it from working. Please use "
                        "{2} instead.")
        used_methods = {
            'grads clipping': clip_val > 0,
            'grads accumulation': accum_batches > 1
        }
        methods = " and ".join(
            name for name, active in used_methods.items() if active)

        # -- Check if optimizer.step() has been directly called
        steps_calls = [o.step.calls for o in optim_dict.values()]
        if max(steps_calls) > 0:
            self.console_log.warning(
                warn_pattern.format(methods, "optimizer.step()",
                                    "self.optimize_step(optimizer)"))

        # -- Check if optimizer.zero_grad() has been directly called
        zero_calls = [o.zero_grad.calls for o in optim_dict.values()]
        if max(zero_calls) > 0:
            self.console_log.warning(
                warn_pattern.format(methods, "optimizer.zero_grad()",
                                    "self.zero_grad_step(optimizer)"))

        # -- Check if model.zero_grad() has been directly called
        if self.zero_grad.calls > 0:
            self.console_log.warning(
                warn_pattern.format(methods, "self.zero_grad()",
                                    "self.zero_grad_step(optimizer)"))

        # -- Check if self.compute_grad_step has been called at least once
        if self.compute_grad_step.calls == 0:
            self.console_log.warning(
                warn_pattern.format(methods, "loss.backward()",
                                    "self.compute_grad_step(loss)"))
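# A minimal usage sketch for the helpers referenced by the warnings above
# (`compute_grad_step`, `optimize_step`, `zero_grad_step`). Only the helper
# names come from the warning messages; the body below is hypothetical and
# assumes a model exposing `self.optimizer` and a loss computation.
#
# def _training_step(self, batch):
#     x, y = batch
#     loss = torch.nn.functional.cross_entropy(self(x), y)
#     self.zero_grad_step(self.optimizer)   # instead of optimizer.zero_grad()
#     self.compute_grad_step(loss)          # instead of loss.backward()
#     self.optimize_step(self.optimizer)    # instead of optimizer.step()
#     return {'loss': loss}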
def call_schedulers_optimizers(self):
    schedulers = self._model.scheduler_optimizer
    if isinstance(schedulers, torch.optim.lr_scheduler._LRScheduler):
        schedulers.step()
    elif is_dict(schedulers):
        for _, scheduler in schedulers.items():
            scheduler.step()
    else:
        raise ValueError(
            "scheduler_optimizer should be a dict or a "
            "torch.optim.lr_scheduler._LRScheduler object")
def save_checkpoint(self, path=None, filename=None, is_best=False):
    if filename is None:
        filename = self._epoch
    if isinstance(filename, int):
        filename = self.checkpoints_format.format(filename)
    if path is None:
        path = self.checkpoints_dir
        if is_best:
            path = self.best_checkpoints_dir
    safe_mkdirs(path, exist_ok=True)

    try:
        filename = os.path.join(path, filename)
        current_state_dict = {
            'global_step': self._global_step,
            'epoch': self._epoch,
            'best_epoch': self.best_epoch,
            'beaten_epochs': self.beaten_epochs,
            'best_epoch_score': self.best_epoch_score,
            'best_stats': self.best_stats,
            'model_state_dict': self._model.state_dict(),
        }

        # -- there might be more than one optimizer
        if is_dict(self._model.optimizer):
            optimizer_state_dict = {}
            for key, opt in self._model.optimizer.items():
                optimizer_state_dict.update({key: opt.state_dict()})
        else:
            optimizer_state_dict = self._model.optimizer.state_dict()
        current_state_dict.update(
            {'optimizer_state_dict': optimizer_state_dict})

        torch.save(current_state_dict, filename)

        # -- track filename to delete later,
        # -- if keep_only_last_checkpoint is set to true
        if not is_best and 'init' not in filename:
            self.last_checkpoint = filename
    except Exception as e:
        self.console_log.error(
            "Error occurred while saving the checkpoint: %s", e)
    return filename
def _train(self):
    args = self.args

    if args.dry_run:
        print(args.extra_args.pretty())
        self.stop()
        return {}

    # -- Training epoch
    # TODO: now 'labelled' is hardcoded, make it general
    self._runner.train_epoch(self._runner.train_loader['labelled'])
    self.epoch = self._runner.epoch

    # -- Validate over all datasets
    val_outputs_flat = OrderedDict()
    for key_loader, val_loader in self._runner.val_loader.items():
        num_batches = self._runner.num_batches_val[key_loader]
        # -- Validation on val_loader
        outputs = self._runner.validate(val_loader,
                                        num_batches=num_batches,
                                        set_name=key_loader,
                                        logger=self._runner.logger)

        # -- TODO: flatten_dict should not be necessary,
        # -- the prefix key is already concatenated in the validate method
        # -- collect and return flattened metrics
        for key_stats, val_stats in outputs['stats'].items():
            if is_dict(val_stats):
                if 'scalar' in val_stats.keys():
                    _flat = flatten_dict(val_stats['scalar'], False)
                elif 'scalars' in val_stats.keys():
                    _flat = flatten_dict(val_stats['scalars'], False)
                else:
                    _flat = flatten_dict(val_stats, False)
            else:
                _flat = {key_stats: val_stats}
            val_outputs_flat.update(_flat)

    # -- Make sure values are scalars and not tensors
    remove_keys = []
    for key, val in val_outputs_flat.items():
        if val.dim() == 0:
            val_outputs_flat[key] = self._get_scalar(val)
        else:
            remove_keys.append(key)
    for key in remove_keys:
        del val_outputs_flat[key]

    return val_outputs_flat
def log_each_epoch(self):
    # -- TODO: add anything else to be logged here

    # -- Log learning rates
    optimizers = self._model.optimizer
    if is_optimizer(optimizers):
        current_lr = optimizers.param_groups[0]['lr']
        self._logger.log_metric('optim/lr', current_lr, self._epoch)
    elif is_dict(optimizers):
        for key, opt in optimizers.items():
            current_lr = opt.param_groups[0]['lr']
            self._logger.log_metric('optim/lr_{}'.format(key),
                                    current_lr, self._epoch)
    else:
        raise ValueError(
            "optimizer should be a dict or a torch.optim.Optimizer object")
def load_checkpoint(self, filename=None, is_best=False):
    """
    Restore a checkpoint with:
    - model state_dict
    - optimizers state_dict
    - global_step
    - epoch
    - best_epoch
    - beaten_epochs
    - best_epoch_score
    - best_stats
    """
    path = self.checkpoints_dir
    ckp_format = self.checkpoints_format
    if is_best:
        path = self.best_checkpoints_dir

    if filename is None:
        filename = os.path.join(path, ckp_format.format(self._epoch))
    elif isinstance(filename, int):
        filename = os.path.join(path, ckp_format.format(filename))

    assert isinstance(filename, str), \
        'filename should be the epoch (int) or the checkpoint path (str)'

    checkpoint = torch.load(filename)
    self._global_step = checkpoint.get('global_step',
                                       checkpoint.get('seen', 0))
    self._epoch = checkpoint['epoch']
    self._model.best_epoch = checkpoint.get('best_epoch', -1)
    self._model.beaten_epochs = checkpoint.get('beaten_epochs', 0)
    self._model.best_epoch_score = checkpoint.get('best_epoch_score', 0)
    self._model.best_stats = checkpoint.get('best_stats', [])
    self._model.load_state_dict(checkpoint['model_state_dict'])

    if is_dict(self._model.optimizer):
        for key in self._model.optimizer.keys():
            self._model.optimizer[key].load_state_dict(
                checkpoint['optimizer_state_dict'][key])
    else:
        self._model.optimizer.load_state_dict(
            checkpoint['optimizer_state_dict'])
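# Hypothetical usage of the checkpoint API above: `filename` may be an epoch
# number, which is expanded through `checkpoints_format`, or a full path to a
# checkpoint file. The paths below are illustrations only.
#
# trainer.load_checkpoint(10)                        # restore epoch 10
# trainer.load_checkpoint('runs/exp1/checkpoints/10.pt')
# trainer.load_checkpoint(is_best=True)              # restore from best_checkpoints_dir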
def _init_optimizer_accum(self):
    """
    Initialize a dictionary of counters to keep track, for each
    optimizer, of how many times `optimizer.step()` has been called.
    The optimizer itself is used as the key of the dictionary.
    """
    # If not already, convert into a dict for convenience
    optim_dict = (self.optimizer if is_dict(self.optimizer)
                  else {'optimizer': self.optimizer})

    # -- Initialize the dict of steps per optimizer
    self._optimizers_accum = {opt: 0 for opt in optim_dict.values()}

    # -- Decorate each optimizer's step() and zero_grad() with counters
    for optim in optim_dict.values():
        optim.step = call_counter(optim.step)
        optim.zero_grad = call_counter(optim.zero_grad)

    # -- Decorate the model's zero_grad as well
    self.zero_grad = call_counter(self.zero_grad)
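# A minimal sketch of the `call_counter` decorator assumed above: it wraps a
# callable and exposes a `.calls` attribute, which is what
# `_check_optimizer_accum` reads. This is an illustration of the interface
# implied by the code, not necessarily the project's actual implementation.
import functools


def call_counter(func):
    """Wrap `func` so the number of invocations is tracked in `.calls`."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)
    wrapper.calls = 0
    return wrapper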
def restore_exp(self):
    # TODO: fix this
    checkpoint = self.reload_checkpoint(
        self.args.restore,
        os.path.join(self._restore_path, 'checkpoints'),
        self.args)

    self._global_step = checkpoint.get('global_step',
                                       checkpoint.get('seen', 0))
    self._epoch = checkpoint['epoch']
    self._model.best_epoch = checkpoint.get('best_epoch', -1)
    self._model.beaten_epochs = checkpoint.get('beaten_epochs', 0)
    self._model.best_epoch_score = checkpoint.get('best_epoch_score', 0)
    self._model.best_stats = checkpoint.get('best_stats', [])
    self._model.load_state_dict(checkpoint['model_state_dict'])

    if is_dict(self._model.optimizer):
        for key in self._model.optimizer.keys():
            self._model.optimizer[key].load_state_dict(
                checkpoint['optimizer_state_dict'][key])
    else:
        self._model.optimizer.load_state_dict(
            checkpoint['optimizer_state_dict'])
def override_with_custom_args(self, extra_args=None):
    """
    Specific arguments can be overridden by:
    - a `custom_config` file, defined via cli.
    - an `extra_args` dict passed to the constructor of the Trainer object.
    - the command line, using dotted notation.
    The arguments should already be defined in the default_config,
    otherwise an exception is raised, since you would be trying to
    modify an argument that does not exist.
    """
    # -- From command line
    self._cli_args = OmegaConf.from_cli()

    # -- From experiment custom config file (passed from cli)
    self._custom_config_args = OmegaConf.create(dict())
    if self._cli_args.custom_config is not None:
        self._custom_config_args = OmegaConf.load(
            self._cli_args.custom_config)

    # -- Extra config from Tune or any script
    if is_dict(extra_args):
        matching = [s for s in extra_args.keys() if "." in s]
        if len(matching) > 0:
            self.console_log.warning(
                "It seems you are using dotted notation in a dictionary! "
                "Please use a list instead, to modify the correct "
                "values! %s", matching)
        self._extra_args = OmegaConf.create(extra_args)
    elif is_list(extra_args):
        self._extra_args = OmegaConf.from_dotlist(extra_args)
    elif extra_args is None:
        self._extra_args = OmegaConf.create(dict())
    else:
        raise ValueError("extra_args should be a list of dotted "
                         "strings or a dict")

    # -- Save optimizer args for later
    dict_opt_custom = deepcopy(self._custom_config_args.optimizer)
    dict_opt_extra = deepcopy(self._extra_args.optimizer)
    if dict_opt_custom is not None:
        del self._custom_config_args['optimizer']
    if dict_opt_extra is not None:
        del self._extra_args['optimizer']

    # -- Override custom args, ONLY IF THEY EXIST
    self._args = OmegaConf.merge(self._args, self._custom_config_args,
                                 self._extra_args)

    # !!NOTE!! Optimizer could drastically change,
    # so it is merged with struct mode disabled
    OmegaConf.set_struct(self._args, False)
    if dict_opt_custom is not None:
        self._args = OmegaConf.merge(
            self._args, OmegaConf.create({'optimizer': dict_opt_custom}))
    if dict_opt_extra is not None:
        self._args = OmegaConf.merge(
            self._args, OmegaConf.create({'optimizer': dict_opt_extra}))
    OmegaConf.set_struct(self._args, True)

    # !!NOTE!! WORKAROUND because of Tune command line args
    OmegaConf.set_struct(self._args, False)
    self._args = OmegaConf.merge(self._args, self._cli_args)
    OmegaConf.set_struct(self._args, True)

    # -- Resolve interpolations to be sure all nodes are explicit
    self._args = OmegaConf.to_container(self._args, resolve=True)
    self._args = OmegaConf.create(self._args)
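# Hypothetical examples of the three override channels described above. The
# `Trainer` constructor call and config file name are illustrations only; the
# dotted-key names must already exist in the default_config.
#
# From a script (dotlist form, preferred over a dotted-key dict):
#     trainer = Trainer(extra_args=['optimizer.params.lr=0.001',
#                                   'accum_batches=4'])
#
# From the command line (OmegaConf dotted notation, plus a custom config):
#     python train.py custom_config=configs/exp1.yaml optimizer.params.lr=0.01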